diff --git a/requirements/common.txt b/requirements/common.txt index f2d1c0762..14d00e799 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -48,5 +48,6 @@ pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss +spnl >= 0.10.0 anthropic == 0.71.0 -model-hosting-container-standards < 1.0.0 \ No newline at end of file +model-hosting-container-standards < 1.0.0 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 70174250c..134866a01 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1384,6 +1384,154 @@ def load_log_config(log_config_file: str | None) -> dict | None: ) return None +if envs.VLLM_V1_SPANS_ENABLED: + import spnl + import time + from fastapi import Body + from vllm import SamplingParams + from vllm.inputs import TokensPrompt + from vllm.outputs import RequestOutput + from vllm.entrypoints.openai.protocol import (ChatMessage,ChatCompletionStreamResponse,ChatCompletionResponseStreamChoice,ChatCompletionResponseChoice,DeltaMessage,UsageInfo) + spnl_state = spnl.init(10) + PAD_TOKEN = 27 + PLUS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_PLUS if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None + CROSS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_CROSS if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None + def wrap(prompt: str | list[str]) -> TokensPrompt: + if isinstance(prompt[0], list): + return [TokensPrompt(prompt_token_ids=p) for p in prompt] + return TokensPrompt(prompt_token_ids=prompt) + @router.post("/v1/query/prepare") + @with_cancellation + @load_aware_call + async def prepare_query(raw_request: Request, + query: str = Body(..., media_type="text/plain")): + docs = [wrap(doc) for doc in spnl.tokenize_prepare( + spnl_state, + query, + True, # we need to preload the prefix of the plus/independent spans + PAD_TOKEN, + PLUS_TOKEN, + raw_request.app.state.vllm_config.cache_config.block_size + )] + + request_id = raw_request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + client = engine_client(raw_request) + generators = [client.generate(doc, SamplingParams(temperature=0,max_tokens=1), request_id) for doc in docs] + for generator in generators: + async for res in generator: + final = res.outputs[0] + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.error.code) + return JSONResponse(content={"success": True}) + + @router.post("/v1/query/execute") + @with_cancellation + @load_aware_call + async def execute_query(raw_request: Request, + query: str = Body(..., media_type="text/plain"), + stream: bool = False): + req = spnl.tokenize_query( + spnl_state, + query, + PAD_TOKEN, + CROSS_TOKEN, + PLUS_TOKEN, + raw_request.app.state.vllm_config.cache_config.block_size + ) + + match req: + case spnl.TokenizedQuery.TokenizedChatCompletionQuery(q): + req = q # intentional fall-through + case spnl.TokenizedQuery.CompletionRequest(q): + request = CompletionRequest(model=q.model, max_tokens=q.max_tokens, temperature=q.temperature, prompt=q.inputs, stream=stream) + # what we want to do, but this is a fastapi endpoint... return create_completion(request, raw_request) + handler = completion(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Completions API") + + try: + generator = await handler.create_completion(request, raw_request) + except OverflowError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, + detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + detail=str(e)) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.error.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + request_id = raw_request.headers.get( + "X-Request-Id") or uuid.uuid4().hex + client = engine_client(raw_request) + generator = client.generate(wrap(req.messages), SamplingParams(n=1 if req.n <= 0 else req.n,temperature=req.temperature if req.temperature is not None else 0,max_tokens=req.max_tokens if req.max_tokens is not None and req.max_tokens != 0 else 2048), request_id) + + if stream: + async def sgen(): + output_idx: List[int] = [0 for _ in range(req.n)] + async for res in generator: + yield ChatCompletionStreamResponse( + id=request_id, + object="chat.completion.chunk", + created=int(time.time()), + model=req.model, + choices=[ + ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role="assistant", content=output.text[output_idx[index]:]), + logprobs=output.logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + ) + for index, output in enumerate(res.outputs) + ] + ).model_dump_json(exclude_unset=True) + for index, output in enumerate(res.outputs): + output_idx[index] = len(output.text) + return StreamingResponse(content=sgen(), media_type="text/event-stream") + + outputs: List[Optional[CompletionOutput]] = [None for _ in range(req.n)] + async for res in generator: + for output in res.outputs: + outputs[output.index] = output + choices = [ + ChatCompletionResponseChoice( + index=index, + message=ChatMessage(role="assistant", content=output.text), + logprobs=output.logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + ) + for index, output in enumerate(outputs) + ] + num_prompt_tokens=0 # TODO + num_generated_tokens=0 # TODO + usage = UsageInfo(prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + + num_generated_tokens) + response = ChatCompletionResponse( + id=request_id, + created=int(time.time()), + model=req.model, + choices=choices, + usage=usage + ) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.error.code) + return JSONResponse(content=response.model_dump()) + class AuthenticationMiddleware: """ diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2012c3fef..354d9e17b 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -203,6 +203,8 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: request.block_hashes, max_cache_hit_length ) ) + if len(request.block_hashes) > 0: + print(f"vLLMCacheHitRate {100*(num_new_computed_tokens/(len(request.block_hashes)*self.block_size)):.2f}% computed={num_new_computed_tokens} requested={len(request.block_hashes)*self.block_size}", flush=True) if self.log_stats: assert self.prefix_cache_stats is not None