diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index b9f438e81..75483caac 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -53,6 +53,7 @@
         destroy_model_parallel,
     )
     from vllm.transformers_utils.tokenizer import get_tokenizer
+    from vllm.lora.request import LoRARequest
     from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
 
     logging.getLogger("vllm").propagate = True
@@ -121,6 +122,7 @@ class VLLMModelConfig(ModelConfig):
             Maximum number of sequences per iteration. Controls batch size at prefill stage. Defaults to 128.
         max_num_batched_tokens (PositiveInt):
             Maximum number of tokens per batch. Defaults to 2048.
+        lora_path (str | None): Path to the LoRA modules to load on top of the base model. Defaults to None.
         subfolder (str | None):
             Subfolder within the model repository. Defaults to None.
         is_async (bool):
@@ -166,6 +168,7 @@ class VLLMModelConfig(ModelConfig):
     pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
     max_num_seqs: PositiveInt = 128  # maximum number of sequences per iteration; This variable and `max_num_batched_tokens` effectively control the batch size at prefill stage. See https://github.com/vllm-project/vllm/issues/2492 for detailed explaination.
     max_num_batched_tokens: PositiveInt = 2048  # maximum number of tokens per batch
+    lora_path: str | None = None  # path to the LoRA modules
     subfolder: str | None = None
     is_async: bool = False  # Whether to use the async version or sync version of the model
     override_chat_template: bool = None
@@ -201,6 +204,12 @@ def __init__(
         self.precision = config.dtype
 
         self.pairwise_tokenization = config.pairwise_tokenization
+
+        # enable LoRA if lora_path is provided
+        if config.lora_path is not None:
+            self.lora_request = LoRARequest("default", 1, config.lora_path)
+        else:
+            self.lora_request = None
 
         self.prompt_manager = PromptManager(self.use_chat_template, self.tokenizer, config.system_prompt)
 
@@ -257,6 +266,7 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]:
             "max_model_len": self._max_length,
             "swap_space": 4,
             "seed": int(config.seed),
+            "enable_lora": config.lora_path is not None,
             "max_num_seqs": int(config.max_num_seqs),
             "max_num_batched_tokens": int(config.max_num_batched_tokens),
         }
@@ -430,11 +440,13 @@ def _generate(
         sampling_params.detokenize = False
 
         if self.data_parallel_size > 1:
-
             @ray.remote(num_gpus=self.tensor_parallel_size)
             def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
                 llm = LLM(**model_args)
-                return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
+                if self.lora_request is not None:
+                    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params, lora_request=self.lora_request)
+                else:
+                    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
@@ -451,11 +463,19 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r
                 if x is not None
             ]
         else:
-            outputs = self.model.generate(
-                prompt_token_ids=inputs,
-                sampling_params=sampling_params,
-                use_tqdm=True,
-            )
+            if self.lora_request is not None:
+                outputs = self.model.generate(
+                    prompt_token_ids=inputs,
+                    sampling_params=sampling_params,
+                    lora_request=self.lora_request,
+                    use_tqdm=True,
+                )
+            else:
+                outputs = self.model.generate(
+                    prompt_token_ids=inputs,
+                    sampling_params=sampling_params,
+                    use_tqdm=True,
+                )
 
         return outputs
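
For reviewers, a minimal standalone sketch (Python, not part of the diff) of the vLLM LoRA flow this patch wires into lighteval: `enable_lora` is set at engine construction and a `LoRARequest` is attached to each `generate` call. The base model name and adapter path below are placeholders, not values from this PR; inside lighteval the path would come from `VLLMModelConfig.lora_path`.

    # Standalone sketch of the vLLM calls this patch relies on.
    # Model name and adapter path are placeholders, not taken from the PR.
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder base model
        enable_lora=True,  # mirrors "enable_lora": config.lora_path is not None
    )

    # Same constructor arguments as in __init__: (name, int id, path).
    lora_request = LoRARequest("default", 1, "/path/to/lora/adapter")  # placeholder path

    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(max_tokens=16),
        lora_request=lora_request,  # per-request adapter, as in _generate()
    )
    print(outputs[0].outputs[0].text)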