diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 75472399fd..44fbe320fd 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -906,6 +906,25 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
             return error_check_ret
     if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
         return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.')
+    if (request.prompt is not None) ^ (request.input_ids is None):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'You must specify exactly one of prompt or input_ids')
+
+    prompt = request.prompt
+    input_ids = request.input_ids
+    image_data = request.image_data
+    if image_data is not None:
+        # convert to openai format
+        image_input = []
+        if not isinstance(image_data, List):
+            image_data = [image_data]
+        for img in image_data:
+            if isinstance(img, str):
+                image_input.append(dict(type='image_url', image_url=dict(url=img)))
+            else:
+                image_input.append(dict(type='image_url', image_url=img))
+        text_input = dict(type='text', text=prompt if prompt else input_ids)
+        prompt = [dict(role='user', content=[text_input] + image_input)]
+        input_ids = None
 
     gen_config = GenerationConfig(
         max_new_tokens=request.max_tokens,
@@ -925,9 +944,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     )
 
     result_generator = VariableInterface.async_engine.generate(
-        messages=request.prompt,
+        messages=prompt,
         session_id=request.session_id,
-        input_ids=request.input_ids,
+        input_ids=input_ids,
         gen_config=gen_config,
         stream_response=True,  # always use stream to enable batching
         sequence_start=True,
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index e78579eded..66b866b4fc 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -439,11 +439,17 @@ class UpdateParamsRequest(BaseModel):
     finished: bool = False
 
 
+# str for url/base64, base64 should be data:image/jpeg;base64, dict should be {'url': url/base64, 'options': ...}
+ImageDataInputItem = Union[str, Dict]
+ImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]
+
+
 # /generate input
 class GenerateReqInput(BaseModel):
     session_id: Optional[int] = -1
     prompt: Optional[str] = None
     input_ids: Optional[List[int]] = None
+    image_data: Optional[ImageDataFormat] = None
     return_logprob: Optional[bool] = None
     max_tokens: int = 128
     stop: Optional[Union[str, List[str]]] = None
diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py
index 91a2e45be0..bfd4386eda 100644
--- a/lmdeploy/vl/engine.py
+++ b/lmdeploy/vl/engine.py
@@ -86,12 +86,16 @@ async def wrap_for_pytorch(
             ]
         )
         """
-        result = self.model.to_pytorch(messages,
-                                       chat_template,
-                                       tokenizer,
-                                       sequence_start,
-                                       tools=tools,
-                                       enable_thinking=enable_thinking)
+        has_input_ids = self.model.has_input_ids(messages)
+        if not has_input_ids:
+            result = self.model.to_pytorch(messages,
+                                           chat_template,
+                                           tokenizer,
+                                           sequence_start,
+                                           tools=tools,
+                                           enable_thinking=enable_thinking)
+        else:
+            result = self.model.to_pytorch_with_input_ids(messages)
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], List):
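The server-side plumbing above is exercised through the `/generate` route. Below is a minimal client sketch, assuming the default api_server address; the session id, token ids and image URL are placeholders, while the request fields (`prompt`, `input_ids`, `image_data`, `max_tokens`) are the `GenerateReqInput` fields used in this diff. Exactly one of `prompt` or `input_ids` may be set, and `image_data` accepts a URL or base64 string, a dict, or a list of either.

```python
# Hedged usage sketch (not part of this diff): host/port, token ids and URL
# are placeholders; only the request fields defined in GenerateReqInput are real.
import requests

payload = {
    'session_id': 7,                                  # must not be an occupied session
    'input_ids': [151644, 872, 198, 151655, 151645],  # pre-tokenized prompt (made-up ids)
    'image_data': 'https://example.com/demo.jpg',     # or a data:image/...;base64 string,
                                                      # a {'url': ...} dict, or a list of these
    'max_tokens': 64,
}
resp = requests.post('http://0.0.0.0:23333/generate', json=payload)
print(resp.text)
```

When `image_data` is present, the handler rewrites the request into an OpenAI-style `role='user'` message whose text part carries either the prompt string or the raw token ids, so the VL preprocessing path can handle it downstream.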
diff --git a/lmdeploy/vl/model/base.py b/lmdeploy/vl/model/base.py
index dee926e26f..f06a175195 100644
--- a/lmdeploy/vl/model/base.py
+++ b/lmdeploy/vl/model/base.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from itertools import groupby
 from typing import Dict, List, Union
 
 import numpy as np
@@ -104,6 +105,18 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
         """  # noqa
         raise NotImplementedError()
 
+    def has_input_ids(self, messages: List[Dict]) -> bool:
+        """Check whether the messages contain input_ids directly.
+
+        Args:
+            messages (List[Dict]): a list of message, which is supposed to be
+                the output of `preprocess`
+        Returns:
+            bool: whether the messages contain input_ids directly
+        """
+        users = [x['content'] for x in messages if x['role'] == 'user']
+        return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)
+
     def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         """Extract image feature. ONLY implement it when the backend is
         turbomind engine.
@@ -168,6 +181,43 @@ def collect_images(messages):
                 }) for x in content if x['type'] == 'image'])
         return images
 
+    def to_pytorch_with_input_ids(self, messages):
+        """Pack the preprocessing results in a format compatible with what is
+        required by pytorch engine when input_ids are provided directly.
+
+        Args:
+            messages(List[Dict]): the output of `preprocess`
+        """
+        # collect all preprocessing result from messages
+        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+        assert len(preps) == 1
+        preps = preps[0]
+
+        _input_ids = messages[0]['content'][0]['text']
+        segs = []
+        for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):
+            if not k:
+                segs.append(list(g))
+            else:
+                segs.extend([[]] * (len(list(g)) - 1))
+        if _input_ids[0] == self.image_token_id:
+            segs = [[]] + segs
+        if _input_ids[-1] == self.image_token_id:
+            segs = segs + [[]]
+
+        assert self.image_token_id == preps[0]['image_token_id']
+        assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '
+                                             f'to input images, {len(segs) - 1} vs {len(preps)}')
+        input_ids = []
+        for i, seg in enumerate(segs):
+            if i > 0 and i <= len(preps):
+                preps[i - 1].update(offset=len(input_ids))
+                image_tokens = preps[i - 1]['image_tokens']
+                input_ids.extend([self.image_token_id] * image_tokens)
+            input_ids.extend(seg)
+
+        return dict(prompt=None, input_ids=input_ids, multimodal=preps)
+
     def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
         """Auxiliary function to pack the preprocessing results in a format
         compatible with what is required by pytorch engine.
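To make the token bookkeeping in `to_pytorch_with_input_ids` easier to follow, here is a standalone sketch of the same segment-and-expand idea with toy numbers. The `IMAGE_TOKEN_ID` value and the per-image `image_tokens` counts are invented for illustration; in the method above they come from the model and from the `preprocess` results.

```python
# Standalone illustration of the segment-and-expand logic (toy values only).
from itertools import groupby

IMAGE_TOKEN_ID = 151655  # made-up placeholder id
# user-supplied input_ids with one placeholder token per image
_input_ids = [101, 102, IMAGE_TOKEN_ID, 103, IMAGE_TOKEN_ID, 104]
preps = [{'image_token_id': IMAGE_TOKEN_ID, 'image_tokens': 3},
         {'image_token_id': IMAGE_TOKEN_ID, 'image_tokens': 2}]

# split the token stream on the image placeholder
segs = []
for is_image, group in groupby(_input_ids, lambda x: x == IMAGE_TOKEN_ID):
    if not is_image:
        segs.append(list(group))
    else:
        segs.extend([[]] * (len(list(group)) - 1))
if _input_ids[0] == IMAGE_TOKEN_ID:
    segs = [[]] + segs
if _input_ids[-1] == IMAGE_TOKEN_ID:
    segs = segs + [[]]
assert len(segs) == len(preps) + 1  # n images -> n + 1 text segments

# re-assemble, expanding each placeholder to the real number of image tokens
input_ids = []
for i, seg in enumerate(segs):
    if 0 < i <= len(preps):
        preps[i - 1]['offset'] = len(input_ids)  # where this image's tokens start
        input_ids.extend([IMAGE_TOKEN_ID] * preps[i - 1]['image_tokens'])
    input_ids.extend(seg)

print(input_ids)
# [101, 102, IMG, IMG, IMG, 103, IMG, IMG, 104] with offsets 2 and 6 recorded
```

The user supplies one placeholder token per image; the helper splits the id stream on those placeholders, records where each image's feature tokens will start (`offset`), and repeats the placeholder `image_tokens` times so the pytorch engine can splice in the visual embeddings at the right positions.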