3 changes: 2 additions & 1 deletion requirements/common.txt
@@ -48,5 +48,6 @@ pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
spnl >= 0.10.0
anthropic == 0.71.0
model-hosting-container-standards < 1.0.0
148 changes: 148 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -1384,6 +1384,154 @@ def load_log_config(log_config_file: str | None) -> dict | None:
)
return None

if envs.VLLM_V1_SPANS_ENABLED:
import spnl
import time
from fastapi import Body
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.entrypoints.openai.protocol import (
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
UsageInfo,
)
spnl_state = spnl.init(10)
PAD_TOKEN = 27
PLUS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_PLUS if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None
CROSS_TOKEN = envs.VLLM_V1_SPANS_TOKEN_CROSS if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None
def wrap(prompt: list[int] | list[list[int]]) -> TokensPrompt | list[TokensPrompt]:
if isinstance(prompt[0], list):
return [TokensPrompt(prompt_token_ids=p) for p in prompt]
return TokensPrompt(prompt_token_ids=prompt)
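# For illustration only (not part of this change): wrap([1, 2, 3]) yields a
# single TokensPrompt, while wrap([[1, 2], [3, 4]]) yields one TokensPrompt
# per tokenized span.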
@router.post("/v1/query/prepare")
@with_cancellation
@load_aware_call
async def prepare_query(raw_request: Request,
query: str = Body(..., media_type="text/plain")):
docs = [wrap(doc) for doc in spnl.tokenize_prepare(
spnl_state,
query,
True, # we need to preload the prefix of the plus/independent spans
PAD_TOKEN,
PLUS_TOKEN,
raw_request.app.state.vllm_config.cache_config.block_size
)]

request_id = raw_request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
client = engine_client(raw_request)
generators = [
client.generate(doc, SamplingParams(temperature=0, max_tokens=1), request_id)
for doc in docs
]
# Drain each generator so the prefill runs and the span prefixes land in the
# prefix cache; the single generated token is discarded.
for generator in generators:
async for res in generator:
final = res.outputs[0]

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
return JSONResponse(content={"success": True})
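# Illustrative only (not part of this change): a client could warm the cache
# for a query's independent spans before executing it. The host/port and the
# query.txt file holding an spnl query are assumptions of this sketch.
#
#   curl -X POST http://localhost:8000/v1/query/prepare \
#        -H "Content-Type: text/plain" \
#        --data-binary @query.txt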

@router.post("/v1/query/execute")
@with_cancellation
@load_aware_call
async def execute_query(raw_request: Request,
query: str = Body(..., media_type="text/plain"),
stream: bool = False):
req = spnl.tokenize_query(
spnl_state,
query,
PAD_TOKEN,
CROSS_TOKEN,
PLUS_TOKEN,
raw_request.app.state.vllm_config.cache_config.block_size
)

match req:
case spnl.TokenizedQuery.TokenizedChatCompletionQuery(q):
req = q # handled by the generic chat path after this match block
case spnl.TokenizedQuery.CompletionRequest(q):
request = CompletionRequest(
model=q.model,
max_tokens=q.max_tokens,
temperature=q.temperature,
prompt=q.inputs,
stream=stream,
)
# Ideally this would simply return create_completion(request, raw_request),
# but that is itself a FastAPI endpoint, so its handler is invoked directly.
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")

try:
generator = await handler.create_completion(request, raw_request)
except OverflowError as e:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value,
detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=str(e)) from e

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

request_id = raw_request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
client = engine_client(raw_request)
sampling_params = SamplingParams(
n=1 if req.n <= 0 else req.n,
temperature=req.temperature if req.temperature is not None else 0,
max_tokens=req.max_tokens if req.max_tokens is not None and req.max_tokens != 0 else 2048,
)
generator = client.generate(wrap(req.messages), sampling_params, request_id)

if stream:
async def sgen():
output_idx: list[int] = [0 for _ in range(max(req.n, 1))]
async for res in generator:
yield ChatCompletionStreamResponse(
id=request_id,
object="chat.completion.chunk",
created=int(time.time()),
model=req.model,
choices=[
ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=output.text[output_idx[index]:]),
logprobs=output.logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
for index, output in enumerate(res.outputs)
]
).model_dump_json(exclude_unset=True)
for index, output in enumerate(res.outputs):
output_idx[index] = len(output.text)
return StreamingResponse(content=sgen(), media_type="text/event-stream")

outputs: list[CompletionOutput | None] = [None for _ in range(max(req.n, 1))]
async for res in generator:
for output in res.outputs:
outputs[output.index] = output
choices = [
ChatCompletionResponseChoice(
index=index,
message=ChatMessage(role="assistant", content=output.text),
logprobs=output.logprobs,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
)
for index, output in enumerate(outputs)
]
num_prompt_tokens = 0 # TODO: report real token counts
num_generated_tokens = 0 # TODO: report real token counts
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=int(time.time()),
model=req.model,
choices=choices,
usage=usage
)

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
return JSONResponse(content=response.model_dump())
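# Illustrative only (not part of this change): executing the same query over
# HTTP, buffered or streamed as server-sent events. The requests library,
# host/port, and query.txt are assumptions, and the response shape assumes
# the query resolves to the chat-completion path above.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/v1/query/execute",
#       params={"stream": "false"},
#       headers={"Content-Type": "text/plain"},
#       data=open("query.txt", "rb").read(),
#   )
#   print(resp.json()["choices"][0]["message"]["content"])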


class AuthenticationMiddleware:
"""
2 changes: 2 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
@@ -203,6 +203,8 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
request.block_hashes, max_cache_hit_length
)
)
if len(request.block_hashes) > 0:
requested_tokens = len(request.block_hashes) * self.block_size
print(f"vLLMCacheHitRate {100 * num_new_computed_tokens / requested_tokens:.2f}% "
f"computed={num_new_computed_tokens} requested={requested_tokens}",
flush=True)
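# Illustrative arithmetic (not part of this change): with block_size=16 and
# 10 block hashes, requested_tokens is 160; if 96 of those tokens are already
# computed, the line above logs a 60.00% prefix-cache hit rate.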

if self.log_stats:
assert self.prefix_cache_stats is not None