diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 3685094c9..23b6c1ab9 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -4,6 +4,7 @@ from tqdm.auto import tqdm from transformers import AutoTokenizer import torch.multiprocessing as mp +from typing import TypedDict from nanovllm.config import Config from nanovllm.sampling_params import SamplingParams @@ -11,6 +12,10 @@ from nanovllm.engine.scheduler import Scheduler from nanovllm.engine.model_runner import ModelRunner +class GenerateOutput(TypedDict): + text: str + token_ids: list[int] + class LLMEngine: @@ -62,7 +67,7 @@ def generate( prompts: list[str] | list[list[int]], sampling_params: SamplingParams | list[SamplingParams], use_tqdm: bool = True, - ) -> list[str]: + ) -> list[GenerateOutput]: pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True, disable=not use_tqdm) if not isinstance(sampling_params, list): sampling_params = [sampling_params] * len(prompts)