diff --git a/src/imitater/model/chat_model.py b/src/imitater/model/chat_model.py
index cce1c25..b329bce 100644
--- a/src/imitater/model/chat_model.py
+++ b/src/imitater/model/chat_model.py
@@ -125,7 +125,8 @@ async def _generate(self, messages: List[Dict[str, str]], request_id: str, **gen
         sampling_params = SamplingParams(
             temperature=gen_kwargs.pop("temperature", self._generation_config.temperature),
             top_p=gen_kwargs.pop("top_p", self._generation_config.top_p),
-            max_tokens=gen_kwargs.pop("max_tokens", self._generation_config.max_new_tokens),
+            # Use the configured limit when max_tokens is missing or falsy (e.g. an explicit None).
+            max_tokens=gen_kwargs.pop("max_tokens", None) or self._generation_config.max_new_tokens,
             stop=gen_kwargs.pop("stop", None),
             stop_token_ids=self._generation_config.eos_token_id + gen_kwargs.pop("stop_token_ids", []),
         )