23 changes: 21 additions & 2 deletions lmdeploy/serve/openai/api_server.py
@@ -906,6 +906,25 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
return error_check_ret
if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.')
if (request.prompt is not None) ^ (request.input_ids is None):
return create_error_response(HTTPStatus.BAD_REQUEST, 'You must specify exactly one of prompt or input_ids')

prompt = request.prompt
input_ids = request.input_ids
image_data = request.image_data
if image_data is not None:
# convert to openai format
image_input = []
if not isinstance(image_data, List):
image_data = [image_data]
for img in image_data:
if isinstance(img, str):
image_input.append(dict(type='image_url', image_url=dict(url=img)))
else:
image_input.append(dict(type='image_url', image_url=img))
text_input = dict(type='text', text=prompt if prompt else input_ids)
prompt = [dict(role='user', content=[text_input] + image_input)]
input_ids = None

gen_config = GenerationConfig(
max_new_tokens=request.max_tokens,
@@ -925,9 +944,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
)

result_generator = VariableInterface.async_engine.generate(
messages=request.prompt,
messages=prompt,
session_id=request.session_id,
input_ids=request.input_ids,
input_ids=input_ids,
gen_config=gen_config,
stream_response=True, # always use stream to enable batching
sequence_start=True,
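As a usage sketch, a request to the extended `/generate` endpoint could combine a text prompt with `image_data` roughly as follows. This assumes an api_server already running at the default address and serving a vision-language model; the session id and image URL are placeholders, not part of this PR:

```python
import requests

# Hypothetical client call: a text prompt plus a single image URL.
# image_data may also be a base64 data URI, a dict, or a list (see protocol.py below).
payload = {
    "session_id": 7,
    "prompt": "Describe the image.",
    "image_data": "https://example.com/cat.jpg",
    "max_tokens": 128,
}
resp = requests.post("http://0.0.0.0:23333/generate", json=payload)
print(resp.text)
```

Exactly one of `prompt` or `input_ids` must be given; when `image_data` is present, the handler repacks both into an OpenAI-style `user` message before calling the async engine.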
6 changes: 6 additions & 0 deletions lmdeploy/serve/openai/protocol.py
@@ -439,11 +439,17 @@ class UpdateParamsRequest(BaseModel):
finished: bool = False


# str: an image URL or a base64 data URI such as 'data:image/jpeg;base64,...'; dict: {'url': url/base64, 'options': ...}
ImageDataInputItem = Union[str, Dict]
ImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]


# /generate input
class GenerateReqInput(BaseModel):
session_id: Optional[int] = -1
prompt: Optional[str] = None
input_ids: Optional[List[int]] = None
image_data: Optional[ImageDataFormat] = None
return_logprob: Optional[bool] = None
max_tokens: int = 128
stop: Optional[Union[str, List[str]]] = None
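For illustration, the shapes the new `image_data` field accepts per `ImageDataFormat` (the values below are made up, and the contents of `options` depend on the vision model):

```python
from lmdeploy.serve.openai.protocol import GenerateReqInput

# 1. a single string: an image URL or a base64 data URI
req1 = GenerateReqInput(prompt="What is in the image?",
                        image_data="https://example.com/cat.jpg")

# 2. a dict: {'url': url or base64, 'options': ...}
req2 = GenerateReqInput(prompt="What is in the image?",
                        image_data={"url": "data:image/jpeg;base64,...", "options": {}})

# 3. a list mixing either form, for multi-image requests
req3 = GenerateReqInput(prompt="Compare the two images.",
                        image_data=["https://example.com/a.jpg",
                                    {"url": "https://example.com/b.jpg"}])
```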
16 changes: 10 additions & 6 deletions lmdeploy/vl/engine.py
@@ -86,12 +86,16 @@ async def wrap_for_pytorch(
]
)
"""
result = self.model.to_pytorch(messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
enable_thinking=enable_thinking)
has_input_ids = self.model.has_input_ids(messages)
if not has_input_ids:
result = self.model.to_pytorch(messages,
chat_template,
tokenizer,
sequence_start,
tools=tools,
enable_thinking=enable_thinking)
else:
result = self.model.to_pytorch_with_input_ids(messages)
# clear data
for i, message in enumerate(messages):
if isinstance(message['content'], List):
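The branch above is driven by the vision model's `has_input_ids` (added in `lmdeploy/vl/model/base.py` below), which inspects the `text` slot of the single user message. A rough sketch of the two shapes, with image and preprocess entries elided:

```python
# Regular path: `text` is a string, so the prompt is chat-templated and
# tokenized inside to_pytorch(...), as before this PR.
user_text = {"role": "user",
             "content": [{"type": "text", "text": "Describe the image."}]}

# Pre-tokenized path: the api_server packed the request's input_ids into the
# `text` slot, so has_input_ids(...) returns True and
# to_pytorch_with_input_ids(...) is used instead.
user_ids = {"role": "user",
            "content": [{"type": "text", "text": [1, 887, 526, 263]}]}  # ids are illustrative
```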
50 changes: 50 additions & 0 deletions lmdeploy/vl/model/base.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from itertools import groupby
from typing import Dict, List, Union

import numpy as np
@@ -104,6 +105,18 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
""" # noqa
raise NotImplementedError()

def has_input_ids(self, messages: List[Dict]) -> bool:
"""Check whether the messages contain input_ids directly.

Args:
messages (List[Dict]): a list of message, which is supposed to be
the output of `preprocess`
Returns:
bool: whether the messages contain input_ids directly
"""
users = [x['content'] for x in messages if x['role'] == 'user']
return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)

def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
"""Extract image feature. ONLY implement it when the backend is
turbomind engine.
@@ -168,6 +181,43 @@ def collect_images(messages):
}) for x in content if x['type'] == 'image'])
return images

def to_pytorch_with_input_ids(self, messages):
"""Pack the preprocessing results in a format compatible with what is
required by pytorch engine when input_ids are provided directly.

Args:
messages(List[Dict]): the output of `preprocess`
"""
# collect all preprocessing result from messages
preps = [x['content'] for x in messages if x['role'] == 'preprocess']
assert len(preps) == 1
preps = preps[0]

_input_ids = messages[0]['content'][0]['text']
segs = []
for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):
if not k:
segs.append(list(g))
else:
segs.extend([[]] * (len(list(g)) - 1))
if _input_ids[0] == self.image_token_id:
segs = [[]] + segs
if _input_ids[-1] == self.image_token_id:
segs = segs + [[]]

assert self.image_token_id == preps[0]['image_token_id']
assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '
f'to input images, {len(segs) - 1} vs {len(preps)}')
input_ids = []
for i, seg in enumerate(segs):
if i > 0 and i <= len(preps):
preps[i - 1].update(offset=len(input_ids))
image_tokens = preps[i - 1]['image_tokens']
input_ids.extend([self.image_token_id] * image_tokens)
input_ids.extend(seg)

return dict(prompt=None, input_ids=input_ids, multimodal=preps)

def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
"""Auxiliary function to pack the preprocessing results in a format
compatible with what is required by pytorch engine.
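To make the segmentation in `to_pytorch_with_input_ids` concrete, here is a small worked sketch; the token-id values and the image placeholder id are made up:

```python
from itertools import groupby

IMAGE_TOKEN_ID = 92544                      # hypothetical placeholder id
_input_ids = [5, 6, IMAGE_TOKEN_ID, 7, 8]   # one image placeholder between two text spans
preps = [{"image_token_id": IMAGE_TOKEN_ID, "image_tokens": 3}]  # this image occupies 3 tokens

# Split on runs of the placeholder; a run of n placeholders contributes n - 1 empty segments.
segs = []
for is_image, group in groupby(_input_ids, lambda x: x == IMAGE_TOKEN_ID):
    if not is_image:
        segs.append(list(group))
    else:
        segs.extend([[]] * (len(list(group)) - 1))
if _input_ids[0] == IMAGE_TOKEN_ID:
    segs = [[]] + segs
if _input_ids[-1] == IMAGE_TOKEN_ID:
    segs = segs + [[]]
# segs == [[5, 6], [7, 8]]  ->  always one more segment than images

# Re-assemble: expand each placeholder to its real image_tokens count and record
# the offset at which that image's tokens start.
input_ids = []
for i, seg in enumerate(segs):
    if i > 0 and i <= len(preps):
        preps[i - 1].update(offset=len(input_ids))
        input_ids.extend([IMAGE_TOKEN_ID] * preps[i - 1]['image_tokens'])
    input_ids.extend(seg)
# input_ids == [5, 6, 92544, 92544, 92544, 7, 8]; preps[0]['offset'] == 2
```

The assertion `len(segs) == len(preps) + 1` is what guarantees that every preprocessed image in `preps` has a matching placeholder position in the incoming ids.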