diff --git a/src/lm_polygraph/model_adapters/whitebox_model_vllm.py b/src/lm_polygraph/model_adapters/whitebox_model_vllm.py
index a3e4e71d8..f95e7fc23 100644
--- a/src/lm_polygraph/model_adapters/whitebox_model_vllm.py
+++ b/src/lm_polygraph/model_adapters/whitebox_model_vllm.py
@@ -4,6 +4,7 @@
 import torch
 
 from typing import List
+from copy import copy
 
 
 class WhiteboxModelvLLM(Model):
@@ -46,8 +47,10 @@ def __init__(
         self.model_type = "vLLMCausalLM"
 
     def generate(self, *args, **kwargs):
-        sampling_params = self.sampling_params
+        sampling_params = copy(self.sampling_params)
         sampling_params.n = kwargs.get("num_return_sequences", 1)
+        if "max_new_tokens" in kwargs:
+            sampling_params.max_tokens = kwargs["max_new_tokens"]
         texts = self.tokenizer.batch_decode(
             kwargs["input_ids"], skip_special_tokens=True
         )