File tree 1 file changed +12
-0
lines changed
1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -556,7 +556,19 @@ async def generate(self,
556
556
assert len (inputs ) == len (session_ids )
557
557
558
558
prompt = inputs
559
+ do_sample = kwargs .pop ('do_sample' , None )
559
560
gen_params = self .update_gen_params (** kwargs )
561
+ if do_sample is None :
562
+ do_sample = self .do_sample
563
+ if do_sample is not None and self .version < (0 , 6 , 0 ):
564
+ raise RuntimeError (
565
+ '`do_sample` parameter is not supported by lmdeploy until '
566
+ f'v0.6.0, but currently using lmdeloy { self .str_version } ' )
567
+ if self .version >= (0 , 6 , 0 ):
568
+ if do_sample is None :
569
+ do_sample = gen_params ['top_k' ] > 1 or gen_params [
570
+ 'temperature' ] > 0
571
+ gen_params .update (do_sample = do_sample )
560
572
gen_config = GenerationConfig (
561
573
skip_special_tokens = skip_special_tokens , ** gen_params )
562
574
You can’t perform that action at this time.
0 commit comments