Commit b5372e0

support returning stop_str in output (#3984)
1 parent e7cbc54 commit b5372e0

4 files changed (+16, -12 lines)


lmdeploy/messages.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ class GenerationConfig:
     logits_processors: Optional[List[LogitsProcessor]] = None
     output_logits: Literal['all', 'generation'] = None
     output_last_hidden_state: Literal['all', 'generation'] = None
+    include_stop_str_in_output: bool = False
 
     # for disaggregation
     with_cache: bool = False
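
A minimal usage sketch of the new field through the pipeline API (model path and prompt are placeholders, and it assumes generation ends on a stop token rather than on max_new_tokens):

from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')  # placeholder model path
gen_config = GenerationConfig(
    max_new_tokens=64,
    include_stop_str_in_output=True,  # new field added by this commit
)
responses = pipe(['Hi, please introduce yourself'], gen_config=gen_config)
print(responses[0].text)  # the decoded stop string is kept at the end of the text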

lmdeploy/serve/async_engine.py

Lines changed: 8 additions & 6 deletions
@@ -680,7 +680,6 @@ async def generate(
             step: int = 0,
             do_preprocess: bool = True,
             adapter_name: Optional[str] = None,
-            skip_stop_tokens: bool = True,
             rewind_stop_tokens: bool = False,
             input_ids: Optional[List] = None,
             enable_thinking: Optional[bool] = None,
@@ -777,9 +776,8 @@ async def generate(
         def is_error(status):
             return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL]
 
-        # used to skip / rewind stop words in interactive mode
         stop_ids = []
-        if skip_stop_tokens and not gen_config.ignore_eos:
+        if not gen_config.ignore_eos:
             stop_ids = gen_config.stop_token_ids or []
 
         metrics_processor.increment_total_requests()
@@ -863,11 +861,15 @@ def is_error(status):
 
             if not is_error(outputs.status):
                 finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
-                # utf-8 char at the end means it's a potential unfinished
-                # byte sequence
+                # utf-8 char at the end means it's a potential unfinished byte sequence
                 if not response.endswith('�'):
                     # avoid returning the last response twice
                     response = ''
+                token_ids = []
+                if gen_config.include_stop_str_in_output and finish_reason == 'stop':
+                    # return the eos token id (MUST be in a list) and its string
+                    token_ids = outputs.token_ids[-1:]
+                    response = self.tokenizer.decode(token_ids, skip_special_tokens=False)
                 logger.info(f'session {session_id} finished, reason '
                             f'"{finish_reason}", input_tokens '
                             f'{len(input_ids)}, output_tokens {gen_len}')
@@ -876,7 +878,7 @@ def is_error(status):
                              len(input_ids),
                              gen_len,
                              finish_reason,
-                             token_ids=[],
+                             token_ids=token_ids,
                              cache_block_ids=outputs.cache_block_ids)
             # Update a session's sequence only when it is in finished status
             if outputs.status == ResponseType.FINISH:
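
For readability, a standalone sketch of the tail handling added above (outputs, gen_config and tokenizer stand in for the engine's own objects, so this is an illustration rather than engine code): when the request asks for the stop string and the sequence ended on a stop token, the final chunk carries that token id as a one-element list together with its decoded text; otherwise it stays empty as before.

def final_chunk(outputs, gen_config, tokenizer, finish_reason):
    # mirrors the branch added in generate(); outputs.token_ids is the generated id list
    token_ids, response = [], ''
    if gen_config.include_stop_str_in_output and finish_reason == 'stop':
        token_ids = outputs.token_ids[-1:]  # keep the stop token id as a list
        response = tokenizer.decode(token_ids, skip_special_tokens=False)
    return token_ids, response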

lmdeploy/serve/openai/api_server.py

Lines changed: 1 addition & 0 deletions
@@ -410,6 +410,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         repetition_penalty=request.repetition_penalty,
         ignore_eos=request.ignore_eos,
         stop_words=request.stop,
+        include_stop_str_in_output=request.include_stop_str_in_output,
         skip_special_tokens=request.skip_special_tokens,
         response_format=response_format,
         logits_processors=logits_processors,
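
Through the OpenAI-compatible server the flag travels as an extra request field. A sketch with the official openai client (base URL, port and model name are placeholders for a running lmdeploy serve api_server instance):

from openai import OpenAI

client = OpenAI(base_url='http://localhost:23333/v1', api_key='none')
resp = client.chat.completions.create(
    model='internlm2_5-7b-chat',  # placeholder served model name
    messages=[{'role': 'user', 'content': 'hi'}],
    extra_body={'include_stop_str_in_output': True},  # lmdeploy-specific field added in this commit
)
print(resp.choices[0].message.content)  # ends with the stop string when stopped on a stop token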

lmdeploy/serve/openai/protocol.py

Lines changed: 6 additions & 6 deletions
@@ -107,16 +107,15 @@ class ChatCompletionRequest(BaseModel):
     """Chat completion request."""
     model: str
 
-    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])  # noqa
+    messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     tools: Optional[List[Tool]] = Field(default=None, examples=[None])
-    tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto',
-                                                                                examples=['none'])  # noqa
+    tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto', examples=['none'])
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = None
     n: Optional[int] = 1
-    logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None])  # noqa
+    logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None])
     max_completion_tokens: Optional[int] = Field(
         default=None,
         examples=[None],
@@ -128,15 +127,15 @@ class ChatCompletionRequest(BaseModel):
         examples=[None],
         deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',
     )
-    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])  # noqa
+    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])
 
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
     reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None
-    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])  # noqa
+    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])
     # additional argument of lmdeploy
     do_preprocess: Optional[bool] = True
     repetition_penalty: Optional[float] = 1.0
@@ -150,6 +149,7 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     enable_thinking: Optional[bool] = None
     return_token_ids: Optional[bool] = False
+    include_stop_str_in_output: Optional[bool] = False
 
 
 class FunctionCall(BaseModel):
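
A quick sketch of how the new request field validates (assuming pydantic v2, which provides model_validate):

from lmdeploy.serve.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(model='m', messages=[{'role': 'user', 'content': 'hi'}])
print(req.include_stop_str_in_output)  # False by default

req = ChatCompletionRequest.model_validate({
    'model': 'm',
    'messages': [{'role': 'user', 'content': 'hi'}],
    'include_stop_str_in_output': True,
})
print(req.include_stop_str_in_output)  # True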
