Commit 0051af7

Merge branch 'main' into support-abort-request
2 parents 6ad84d9 + a6aa375

14 files changed: +168 additions, -82 deletions


lmdeploy/pytorch/engine/guided_process.py

Lines changed: 11 additions & 6 deletions
@@ -10,7 +10,7 @@
 logger = logging.getLogger('lmdeploy')


-class GuidedDecodingMangager:
+class GuidedDecodingManager:
     processors = {}

     def __init__(self, tokenizer: PreTrainedTokenizerBase, vocab_size: Optional[int]):
@@ -26,7 +26,8 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
         processors = {}
         for i, _format in enumerate(response_formats):
             if isinstance(_format, Dict) and _format.get('type', 'text') != 'text':
-                if _format['type'] == 'json_schema':
+                schema_type = _format['type']
+                if schema_type == 'json_schema':
                     schema = _format['json_schema']
                     if isinstance(schema, Dict):
                         for key in ['json_schema', 'schema']:
@@ -37,15 +38,17 @@ def get_processors(self, session_ctx: List[Dict[str, Any]],
                         raise ValueError(f'Cannot parse schema {schema}. The schema must be '
                                          'either a dictionary or a string that contains the'
                                          ' JSON Schema specification')
-                elif _format['type'] == 'regex_schema':
+                elif schema_type == 'regex_schema':
                     schema = _format.get('regex_schema', '')
+                elif schema_type == 'json_object':
+                    schema = '{"type" : "object", "additionalProperties": true}'
                 else:
-                    raise ValueError(f"unsupported format type: {_format['type']}")
+                    raise ValueError(f'unsupported format type: {schema_type}')

                 session_id = session_ctx[i]['session_id']
                 seq_id = session_ctx[i]['seq_id']

-                processors[i] = self.get_processor(session_id, seq_id, schema, _format['type'])
+                processors[i] = self.get_processor(session_id, seq_id, schema, schema_type)

         return processors

@@ -63,7 +66,9 @@ def get_processor(self, session_id: int, seq_id: int, schema: str, type: str) ->
             assert isinstance(schema, dict)
             compiled = self.compiler.compile_json_schema(schema)
         elif type == 'regex_schema':
-            compiled = self.compiler.compile_regex_grammar(schema)
+            compiled = self.compiler.compile_regex(schema)
+        elif type == 'json_object':
+            compiled = self.compiler.compile_json_schema(schema)
         else:
             assert False, f'Do not support schema type {type}'
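
As a point of reference, the three response_format shapes that GuidedDecodingManager.get_processors now accepts could look like the sketch below (the schema contents and regex are illustrative, not taken from this commit):

# Illustrative response_format values; only the 'type' handling mirrors the code above.
response_formats = [
    {'type': 'json_schema', 'json_schema': {'schema': {'type': 'object', 'properties': {'name': {'type': 'string'}}}}},
    {'type': 'regex_schema', 'regex_schema': r'\d{4}-\d{2}-\d{2}'},
    {'type': 'json_object'},  # no schema required; expands to a permissive object schema
]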

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 2 additions & 7 deletions
@@ -8,7 +8,7 @@
 from lmdeploy.messages import LogitsProcessor

 from ..messages import SchedulerSequence
-from .guided_process import GuidedDecodingMangager
+from .guided_process import GuidedDecodingManager


 def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor):
@@ -143,12 +143,10 @@ class FusedLogitsProcessor:
     def __init__(
         self,
         sampling_inputs: SamplingInputs,
-        sampling_vocab_size: Optional[int] = None,
         logprobs_mode: Optional[str] = None,
-        guided_decoding_manager: Optional[GuidedDecodingMangager] = None,
+        guided_decoding_manager: Optional[GuidedDecodingManager] = None,
     ):
         self.sampling_inputs: SamplingInputs = sampling_inputs
-        self.sampling_vocab_size = sampling_vocab_size
         self.logprobs_mode = logprobs_mode
         self.guided_decoding_manager = guided_decoding_manager
         if sampling_inputs.session_to_cleanup:
@@ -266,9 +264,6 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
             offsets = sampling_inputs.random_offsets
             return _multinomial_sampling(softmax_scores, seeds, offsets, indices)

-        if self.sampling_vocab_size is not None and logits.size(1) > self.sampling_vocab_size:
-            logits = logits[..., :self.sampling_vocab_size]
-
        if sampling_inputs.max_top_k == 1:
            result = logits.argmax(-1)
        else:

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 5 additions & 9 deletions
@@ -30,7 +30,7 @@
 from ..utils import get_gpu_memory
 from ..weight_loader.model_weight_loader import load_model_weights
 from .cache_engine import CacheEngine
-from .guided_process import GuidedDecodingMangager
+from .guided_process import GuidedDecodingManager
 from .logits_process import FusedLogitsProcessor, SamplingInputs

 logger = get_logger('lmdeploy')
@@ -248,7 +248,8 @@ def model_forward(
         output = model(**input_dict)

     # InternVL-3.5-Flash will change the seqlen, model_metas during forward
-    model_metas = context.model_metas
+    if context.model_metas is not None and context.model_metas[0] is not None:
+        model_metas = context.model_metas
     seq_length = context.q_seqlens[:len(inputs.seq_length)]

     return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)
@@ -315,10 +316,6 @@ def __init__(self,
         self.cache_config = cache_config
         # use raw tokenizer
         self.tokenizer = Tokenizer(model_path).model.model
-        try:
-            self.sampling_vocab_size = len(self.tokenizer)
-        except BaseException:
-            self.sampling_vocab_size = None

         self._pre_in_que = None
         self._in_que = None
@@ -354,9 +351,9 @@ def __init__(self,
         self.cache_engine = None
         self.profiler: AgentProfiler = None
         try:
-            self.guided_decoding_manager = GuidedDecodingMangager(self.tokenizer, self.sampling_vocab_size)
+            self.guided_decoding_manager = GuidedDecodingManager(self.tokenizer, model_config.vocab_size)
         except ValueError as e:
-            logger.warning(f'Failed to create GuidedManager for tokenizer {self.tokenizer}: {e}')
+            logger.warning(f'Failed to create GuidedManager for tokenizer {type(self.tokenizer)}: {e}')
             self.guided_decoding_manager = None

         # microbatch
@@ -552,7 +549,6 @@ async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: Sam
         with record_function('sampling_logits'):
             logits_processor = FusedLogitsProcessor(
                 sampling_inputs,
-                sampling_vocab_size=self.sampling_vocab_size,
                 logprobs_mode=self.misc_config.logprobs_mode,
                 guided_decoding_manager=self.guided_decoding_manager,
             )

lmdeploy/serve/openai/api_server.py

Lines changed: 22 additions & 3 deletions
@@ -591,7 +591,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
             tool_calls = None
             reasoning_content = None
             if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
-                try:  # TODO add json_schema guidance to turbomind
+                try:
                     tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
                     text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
                     if isinstance(tool_calls, List) and len(tool_calls):
@@ -907,6 +907,25 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
         return error_check_ret
     if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
         return create_error_response(HTTPStatus.BAD_REQUEST, f'The session_id `{request.session_id}` is occupied.')
+    if (request.prompt is not None) ^ (request.input_ids is None):
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'You must specify exactly one of prompt or input_ids')
+
+    prompt = request.prompt
+    input_ids = request.input_ids
+    image_data = request.image_data
+    if image_data is not None:
+        # convert to openai format
+        image_input = []
+        if not isinstance(image_data, List):
+            image_data = [image_data]
+        for img in image_data:
+            if isinstance(img, str):
+                image_input.append(dict(type='image_url', image_url=dict(url=img)))
+            else:
+                image_input.append(dict(type='image_url', image_url=img))
+        text_input = dict(type='text', text=prompt if prompt else input_ids)
+        prompt = [dict(role='user', content=[text_input] + image_input)]
+        input_ids = None

     gen_config = GenerationConfig(
         max_new_tokens=request.max_tokens,
@@ -926,9 +945,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     )

     result_generator = VariableInterface.async_engine.generate(
-        messages=request.prompt,
+        messages=prompt,
         session_id=request.session_id,
-        input_ids=request.input_ids,
+        input_ids=input_ids,
         gen_config=gen_config,
         stream_response=True,  # always use stream to enable batching
         sequence_start=True,
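
For illustration, a hypothetical client call exercising the new image_data handling in /generate might look like the sketch below (host, port, and the image URL are assumptions; only the field names come from the handler above):

# Minimal sketch of a /generate request carrying an image; values are placeholders.
import requests

payload = {
    'session_id': 1,
    'prompt': 'Describe the image.',
    'image_data': 'https://example.com/cat.jpg',  # str, dict, or a list of either
    'max_tokens': 64,
}
resp = requests.post('http://0.0.0.0:23333/generate', json=payload)
print(resp.text)

The handler converts such a request into a single user message containing one text item and one image_url item before handing it to the async engine.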

lmdeploy/serve/openai/protocol.py

Lines changed: 6 additions & 0 deletions
@@ -439,11 +439,17 @@ class UpdateParamsRequest(BaseModel):
     finished: bool = False


+# str for url/base64, base64 should be data:image/jpeg;base64, dict should be {'url': url/base64, 'options': ...}
+ImageDataInputItem = Union[str, Dict]
+ImageDataFormat = Union[ImageDataInputItem, List[ImageDataInputItem]]
+
+
 # /generate input
 class GenerateReqInput(BaseModel):
     session_id: Optional[int] = -1
     prompt: Optional[str] = None
     input_ids: Optional[List[int]] = None
+    image_data: Optional[ImageDataFormat] = None
     return_logprob: Optional[bool] = None
     max_tokens: int = 128
     stop: Optional[Union[str, List[str]]] = None
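
The alias admits several shapes for image_data; the values below are placeholders that only illustrate the str/dict/list structure defined above:

# Each assignment is a valid ImageDataFormat value (contents are made up).
image_data = 'https://example.com/cat.jpg'                            # plain URL
image_data = 'data:image/jpeg;base64,...'                             # base64 data URI
image_data = {'url': 'https://example.com/cat.jpg', 'options': {}}    # dict form
image_data = ['https://example.com/a.jpg', {'url': 'https://example.com/b.jpg'}]  # list of items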

lmdeploy/turbomind/deploy/config.py

Lines changed: 0 additions & 3 deletions
@@ -54,9 +54,6 @@ class ModelConfig:
     # Therefore, we add a new attr "embedding_size" to represent the vocab dim
     # of token_embedding
     embedding_size: int = 0
-    # for some models like qwen2.5, the vocab size of the model is larger than
-    # the vocab size of the tokenizer.
-    tokenizer_size: int = None
     num_layer: int = None
     inter_size: List[int] = None
     norm_eps: float = None

lmdeploy/turbomind/deploy/target_model/base.py

Lines changed: 0 additions & 4 deletions
@@ -101,10 +101,6 @@ def update_model_config(self):
         final_cfg.update(self.input_model_info)
         if 'embedding_size' not in self.input_model_info.keys():
             final_cfg.update(embedding_size=self.input_model_info['vocab_size'])
-        from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.input_model.tokenizer_path, trust_remote_code=True)
-        tokenizer_size = min(len(tokenizer), final_cfg['vocab_size'])
-        final_cfg.update(tokenizer_size=tokenizer_size)

         self.model_config = config_from_dict(ModelConfig, final_cfg)

lmdeploy/turbomind/turbomind.py

Lines changed: 10 additions & 2 deletions
@@ -720,7 +720,12 @@ async def async_stream_infer(self,
             try:
                 tokenizer_info = TokenizerInfo.from_huggingface(tokenizer.model.model, vocab_size=vocab_size)
                 decode_grammar_type = gen_config.response_format['type']
-                decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
+                if decode_grammar_type == 'json_schema':
+                    decode_grammar = gen_config.response_format[decode_grammar_type]['schema']
+                elif decode_grammar_type == 'regex_schema':
+                    decode_grammar = gen_config.response_format[decode_grammar_type]
+                elif decode_grammar_type == 'json_object':
+                    decode_grammar = '{"type" : "object", "additionalProperties": true}'

                 compiler = _xgr.GrammarCompiler(tokenizer_info)

@@ -730,9 +735,12 @@
                 elif decode_grammar_type == 'regex_schema':
                     decode_grammar = str(decode_grammar)
                     grammar = compiler.compile_regex(decode_grammar)
+                elif decode_grammar_type == 'json_object':
+                    decode_grammar = str(decode_grammar)
+                    grammar = compiler.compile_json_schema(decode_grammar)
                 else:
                     assert False, f'Decode grammar type {decode_grammar_type} should be in ' \
-                                  '["json_schema", "regex_schema"]'
+                                  '["json_schema", "regex_schema", "json_object"]'

                 self.model_inst.set_grammar(grammar)
             except ValueError as e:
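
As a usage sketch, the new json_object type can be requested through GenerationConfig.response_format, which the code above expands into a permissive JSON schema before compiling it with xgrammar (the pipeline call and model path are assumptions, not part of this commit):

# Minimal sketch, assuming lmdeploy's pipeline API; the model path is a placeholder.
from lmdeploy import GenerationConfig, pipeline

pipe = pipeline('internlm/internlm2_5-7b-chat')
gen_config = GenerationConfig(response_format={'type': 'json_object'})
print(pipe(['List three fruits as a JSON object.'], gen_config=gen_config))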

lmdeploy/vl/engine.py

Lines changed: 10 additions & 6 deletions
@@ -86,12 +86,16 @@ async def wrap_for_pytorch(
             ]
         )
         """
-        result = self.model.to_pytorch(messages,
-                                       chat_template,
-                                       tokenizer,
-                                       sequence_start,
-                                       tools=tools,
-                                       enable_thinking=enable_thinking)
+        has_input_ids = self.model.has_input_ids(messages)
+        if not has_input_ids:
+            result = self.model.to_pytorch(messages,
+                                           chat_template,
+                                           tokenizer,
+                                           sequence_start,
+                                           tools=tools,
+                                           enable_thinking=enable_thinking)
+        else:
+            result = self.model.to_pytorch_with_input_ids(messages)
         # clear data
         for i, message in enumerate(messages):
             if isinstance(message['content'], List):

lmdeploy/vl/model/base.py

Lines changed: 50 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
+from itertools import groupby
 from typing import Dict, List, Union

 import numpy as np
@@ -104,6 +105,18 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]:
         """  # noqa
         raise NotImplementedError()

+    def has_input_ids(self, messages: List[Dict]) -> bool:
+        """Check whether the messages contain input_ids directly.
+
+        Args:
+            messages (List[Dict]): a list of message, which is supposed to be
+                the output of `preprocess`
+        Returns:
+            bool: whether the messages contain input_ids directly
+        """
+        users = [x['content'] for x in messages if x['role'] == 'user']
+        return len(users) == 1 and isinstance(users[0], List) and isinstance(users[0][0].get('text', ''), List)
+
     def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
         """Extract image feature. ONLY implement it when the backend is
         turbomind engine.
@@ -168,6 +181,43 @@ def collect_images(messages):
             }) for x in content if x['type'] == 'image'])
         return images

+    def to_pytorch_with_input_ids(self, messages):
+        """Pack the preprocessing results in a format compatible with what is
+        required by pytorch engine when input_ids are provided directly.
+
+        Args:
+            messages(List[Dict]): the output of `preprocess`
+        """
+        # collect all preprocessing result from messages
+        preps = [x['content'] for x in messages if x['role'] == 'preprocess']
+        assert len(preps) == 1
+        preps = preps[0]
+
+        _input_ids = messages[0]['content'][0]['text']
+        segs = []
+        for k, g in groupby(_input_ids, lambda x: x == self.image_token_id):
+            if not k:
+                segs.append(list(g))
+            else:
+                segs.extend([[]] * (len(list(g)) - 1))
+        if _input_ids[0] == self.image_token_id:
+            segs = [[]] + segs
+        if _input_ids[-1] == self.image_token_id:
+            segs = segs + [[]]
+
+        assert self.image_token_id == preps[0]['image_token_id']
+        assert len(segs) == len(preps) + 1, (f'the number of image token id {self.image_token_id} is not equal '
+                                             f'to input images, {len(segs) - 1} vs {len(preps)}')
+        input_ids = []
+        for i, seg in enumerate(segs):
+            if i > 0 and i <= len(preps):
+                preps[i - 1].update(offset=len(input_ids))
+                image_tokens = preps[i - 1]['image_tokens']
+                input_ids.extend([self.image_token_id] * image_tokens)
+            input_ids.extend(seg)
+
+        return dict(prompt=None, input_ids=input_ids, multimodal=preps)
+
     def to_pytorch_aux(self, messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start):
         """Auxiliary function to pack the preprocessing results in a format
         compatible with what is required by pytorch engine.
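
To make the interleaving in to_pytorch_with_input_ids concrete, here is a small worked sketch with made-up values (image_token_id, the token ids, and image_tokens are all hypothetical):

# Hypothetical inputs: one image placeholder that expands to 3 image tokens.
from itertools import groupby

image_token_id = 92544
_input_ids = [1, 5, 92544, 7, 8]
preps = [{'image_token_id': 92544, 'image_tokens': 3}]

segs = []
for is_img, g in groupby(_input_ids, lambda x: x == image_token_id):
    if not is_img:
        segs.append(list(g))
    else:
        segs.extend([[]] * (len(list(g)) - 1))
# segs == [[1, 5], [7, 8]]: one image slot sits between the two text segments

input_ids = []
for i, seg in enumerate(segs):
    if 0 < i <= len(preps):
        preps[i - 1]['offset'] = len(input_ids)
        input_ids.extend([image_token_id] * preps[i - 1]['image_tokens'])
    input_ids.extend(seg)
# input_ids == [1, 5, 92544, 92544, 92544, 7, 8]; preps[0]['offset'] == 2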
