3 changes: 3 additions & 0 deletions lmdeploy/cli/serve.py
@@ -73,6 +73,7 @@ def add_parser_api_server():
ArgumentHelper.max_log_len(parser)
ArgumentHelper.disable_fastapi_docs(parser)
ArgumentHelper.allow_terminate_by_client(parser)
ArgumentHelper.enable_abort_handling(parser)
# chat template args
ArgumentHelper.chat_template(parser)

@@ -266,6 +267,7 @@ def api_server(args):
allow_methods=args.allow_methods,
allow_headers=args.allow_headers,
allow_terminate_by_client=args.allow_terminate_by_client,
enable_abort_handling=args.enable_abort_handling,
log_level=args.log_level.upper(),
api_keys=args.api_keys,
ssl=args.ssl,
@@ -293,6 +295,7 @@ def api_server(args):
allow_methods=args.allow_methods,
allow_headers=args.allow_headers,
allow_terminate_by_client=args.allow_terminate_by_client,
enable_abort_handling=args.enable_abort_handling,
log_level=args.log_level.upper(),
api_keys=args.api_keys,
ssl=args.ssl,
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -454,6 +454,16 @@ def allow_terminate_by_client(parser):
default=False,
help='Enable server to be terminated by request from client')

@staticmethod
def enable_abort_handling(parser):
"""Add --enable-abort-handling argument to configure server abort
request processing."""

return parser.add_argument('--enable-abort-handling',
action='store_true',
default=False,
help='Enable server to handle client abort requests')

@staticmethod
def cache_max_entry_count(parser):
"""Add argument cache_max_entry_count to parser."""
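For reference, the new flag can also be set through the Python entry point, since `serve()` in lmdeploy/serve/openai/api_server.py gains a matching `enable_abort_handling` keyword later in this diff. A minimal sketch, not part of the PR, assuming a placeholder model path and the usual `server_name`/`server_port` keywords; the CLI equivalent is `lmdeploy serve api_server <model> --enable-abort-handling`.

# Minimal sketch: start the API server from Python with abort handling enabled.
# The model path and port below are placeholders, not values from this PR.
from lmdeploy.serve.openai.api_server import serve

serve(
    'internlm/internlm2_5-7b-chat',  # placeholder model path
    server_name='0.0.0.0',
    server_port=23333,
    enable_abort_handling=True,      # exposes the /abort_request endpoint added below
)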
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -540,7 +540,7 @@ def _on_stop_session(self, reqs: List[Request], **kwargs):
for seq in session.sequences.values():
_resp: Response = getattr(seq, 'resp', None)
if _resp is not None:
_resp.type = ResponseType.FINISH
_resp.type = ResponseType.CANCEL
self.req_manager.response(_resp)
resp_type = ResponseType.SUCCESS
if resp:
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/engine/mp_engine/base_worker.py
@@ -20,7 +20,8 @@ class EngineInstancePool:
def __init__(self, engine):
from lmdeploy.pytorch.engine import Engine
self.engine: Engine = engine
self.num_instance = self.engine.engine_config.max_batch_size
# enlarge `num_instance`, otherwise a sequence cannot be stopped in time
self.num_instance = self.engine.engine_config.max_batch_size * 2
self.pool = None

def create_instance_pool(self, num_instance: int):
24 changes: 21 additions & 3 deletions lmdeploy/serve/async_engine.py
@@ -444,12 +444,26 @@ async def do_log_stats(self):
for stat_logger in self.stat_loggers:
stat_logger.log()

async def stop_all_session(self):
"""Stop all running sessions."""
logger.info('stop all sessions')
tasks = []
session_ids = []
for session_id in list(self.id2inst.keys()):
generator = self.id2inst.get(session_id)
if generator:
session_ids.append(session_id)
tasks.append(generator.async_cancel(session_id))
await asyncio.gather(*tasks)
logger.info(f'all {len(session_ids)} sessions stopped')

async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
logger.info(f'stop session {session_id}')
generator = self.id2inst.get(session_id)
if generator:
await generator.async_cancel(session_id)
logger.info(f'session {session_id} stopped')
# else it's not running at all

async def end_session(self, session_id: int):
@@ -855,7 +869,7 @@ def is_error(status):
break

output_len = len(outputs.token_ids)
if hit_stop_token:
if hit_stop_token or output_len == 0:
continue

# This assumes the engine will stop when stop token is hit
@@ -892,7 +906,11 @@ def is_error(status):
metrics_processor.increment_finished_requests()

if not is_error(outputs.status):
finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
if outputs.status == ResponseType.CANCEL:
finish_reason = 'abort'
else:
finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'

# utf-8 char at the end means it's a potential unfinished byte sequence
if not response.endswith('�'):
# avoid returning the last response twice
@@ -926,7 +944,7 @@ def is_error(status):
output_len = gen_len
self.id2step[session_id] += input_len + output_len
else:
logger.error(f'session {session_id} finished, '
logger.error(f'session {session_id} finished, {outputs.status}, '
'reason "error"')
yield GenOut(response=f'internal error happened, status code {outputs.status}',
history_token_len=self.id2step[session_id],
2 changes: 1 addition & 1 deletion lmdeploy/serve/openai/api_client.py
@@ -13,7 +13,7 @@ def get_model_list(api_url: str, headers: dict = None):
logger = get_logger('lmdeploy')
if not response.ok:
logger.error(f'Failed to get the model list: {api_url}'
'returns {response.status_code}')
f' returns {response.status_code}')
return None
elif not hasattr(response, 'text'):
logger.warning('Failed to get the model list.')
20 changes: 19 additions & 1 deletion lmdeploy/serve/openai/api_server.py
@@ -32,7 +32,7 @@
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
from lmdeploy.serve.openai.protocol import ChatCompletionResponse # noqa: E501
from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice,
from lmdeploy.serve.openai.protocol import (AbortRequest, ChatCompletionRequest, ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
ChatCompletionTokenLogprob, ChatMessage, ChoiceLogprobs, CompletionRequest,
CompletionResponse, CompletionResponseChoice,
@@ -65,6 +65,7 @@ class VariableInterface:
# following is for tool parsers
tool_parser: Optional[ToolParser] = None
allow_terminate_by_client: bool = False
enable_abort_handling: bool = False


router = APIRouter()
@@ -1149,6 +1150,21 @@ async def free_cache(cache_free_request: DistServeCacheFreeRequest) -> JSONRespo
""" PD Disaggregation API End """


@router.post('/abort_request')
async def abort_request(request: AbortRequest, raw_request: Request = None):
Collaborator:

This does not look like a safe API, since a malicious client could use it to stop sessions that do not belong to it. It is better to require enabling it manually and to add a safety warning here.

"""Abort an ongoing request."""
if not VariableInterface.enable_abort_handling:
return Response(
status_code=501,
content='This server does not support abort requests. Enable with --enable-abort-handling flag.')

if request.abort_all:
await VariableInterface.async_engine.stop_all_session()
else:
await VariableInterface.async_engine.stop_session(request.session_id)
return Response(status_code=200)


@router.post('/v1/chat/interactive', dependencies=[Depends(check_api_key)])
async def chat_interactive_v1(request: GenerateRequest, raw_request: Request = None):
return create_error_response(
@@ -1313,6 +1329,7 @@ def serve(model_path: str,
reasoning_parser: Optional[str] = None,
tool_call_parser: Optional[str] = None,
allow_terminate_by_client: bool = False,
enable_abort_handling: bool = False,
**kwargs):
"""An example to perform model inference through the command line
interface.
@@ -1371,6 +1388,7 @@ def serve(model_path: str,
logger.setLevel(log_level)

VariableInterface.allow_terminate_by_client = allow_terminate_by_client
VariableInterface.enable_abort_handling = enable_abort_handling
if api_keys is not None:
if isinstance(api_keys, str):
api_keys = api_keys.split(',')
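To illustrate how a client would use the endpoint added above, here is a hedged sketch using the requests library; the base URL and session id are placeholders, and the server must have been started with --enable-abort-handling, otherwise it answers 501.

# Sketch of a client aborting generation; URL and session id are placeholders.
import requests

base_url = 'http://0.0.0.0:23333'

# Abort a single session by the id the client attached to its own request.
requests.post(f'{base_url}/abort_request', json={'session_id': 1})

# Abort every running session; `session_id` is ignored in this mode.
requests.post(f'{base_url}/abort_request', json={'abort_all': True})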
12 changes: 11 additions & 1 deletion lmdeploy/serve/openai/protocol.py
@@ -333,7 +333,7 @@ class CompletionResponseStreamChoice(BaseModel):
text: str
logprobs: Optional[LogProbs] = None
gen_tokens: Optional[List[int]] = None
finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error', 'abort']] = None


class CompletionStreamResponse(BaseModel):
@@ -472,3 +472,13 @@ class GenerateReqOutput(BaseModel):
text: str
output_ids: List[int]
meta_info: GenerateReqMetaOutput


class AbortRequest(BaseModel):
# Whether to abort all requests
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
abort_message: Optional[str] = None
# The session ID to abort. If `abort_all` is True, this field is ignored.
session_id: Optional[int] = -1
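Seen from the model side, the same two payload shapes look as follows; a small sketch assuming pydantic v2 (use .dict() instead of .model_dump() on v1).

# Sketch of the two AbortRequest payloads accepted by /abort_request.
from lmdeploy.serve.openai.protocol import AbortRequest

one = AbortRequest(session_id=1, abort_message='client disconnected')
everything = AbortRequest(abort_all=True)  # session_id is ignored when abort_all is set

print(one.model_dump())
print(everything.model_dump())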