Skip to content

Commit c337aa8

Browse files
authored
Add do_sample in AsyncLMDeployPipeline (lmdeploy_wrapper.py) (#290)
1 parent a58c914 commit c337aa8

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

lagent/llms/lmdeploy_wrapper.py

+12
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,19 @@ async def generate(self,
556556
assert len(inputs) == len(session_ids)
557557

558558
prompt = inputs
559+
do_sample = kwargs.pop('do_sample', None)
559560
gen_params = self.update_gen_params(**kwargs)
561+
if do_sample is None:
562+
do_sample = self.do_sample
563+
if do_sample is not None and self.version < (0, 6, 0):
564+
raise RuntimeError(
565+
'`do_sample` parameter is not supported by lmdeploy until '
566+
f'v0.6.0, but currently using lmdeloy {self.str_version}')
567+
if self.version >= (0, 6, 0):
568+
if do_sample is None:
569+
do_sample = gen_params['top_k'] > 1 or gen_params[
570+
'temperature'] > 0
571+
gen_params.update(do_sample=do_sample)
560572
gen_config = GenerationConfig(
561573
skip_special_tokens=skip_special_tokens, **gen_params)
562574

0 commit comments

Comments
 (0)