1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
@@ -580,6 +580,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"Qwen2_5OmniModel",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
"Phi4MMForCausalLM",
24 changes: 21 additions & 3 deletions python/sglang/srt/conversation.py
@@ -573,9 +573,9 @@ def generate_chat_conv(
num_image_url += 1
conv.modalities.append(content.modalities)
image_token = (
conv.image_token + "\n"
if conv.name != "qwen2-vl"
else conv.image_token
conv.image_token
if "qwen2" in conv.name
else conv.image_token + "\n"
Comment on lines +576 to +578 (severity: medium)
The logic for adding a newline to image_token has been inverted. While this seems to correctly handle the qwen2 family of models, this change is quite broad and could have unintended consequences for other models that are not in the qwen2 family. It would be safer to make this logic more specific to the models that require no newline, rather than making it the default for all qwen2 models and changing the behavior for all other models.
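A minimal sketch of the narrower check suggested here, assuming a hypothetical allow-list of conversation names (the set contents are illustrative, not part of this PR):

```python
# Hypothetical allow-list: only templates listed here skip the trailing
# newline; every other template keeps the previous "+ \n" behavior.
_NO_NEWLINE_IMAGE_TOKEN_CONVS = {"qwen2-vl", "qwen2-5-o"}  # assumed names

def format_image_token(conv_name: str, image_token: str) -> str:
    if conv_name in _NO_NEWLINE_IMAGE_TOKEN_CONVS:
        return image_token
    return image_token + "\n"
```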

)
add_token_as_needed: bool = (
conv.name in _MODELS_REQUIRING_MODALITY_SUPPLEMENT
@@ -795,6 +795,22 @@ def generate_chat_conv(
)
)

# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
register_conv_template(
Conversation(
name="qwen2-5-o",
system_message="You are a helpful assistant.",
system_template="<|im_start|>system\n{system_message}",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_bos|><|IMAGE|><|vision_eos|>",
# video_token="<|vision_bos|><|VIDEO|><|vision_eos|>",
audio_token="<|audio_bos|><|AUDIO|><|audio_eos|>",
)
)

register_conv_template(
Conversation(
name="deepseek-vl2",
@@ -955,6 +971,8 @@ def match_qwen_chat_ml(model_path: str):
return "gme-qwen2-vl"
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
return "qwen2-vl"
if re.search(r"qwen2.5.*omni.*", model_path, re.IGNORECASE):
return "qwen2-5-o"
if re.search(
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
model_path,
576 changes: 576 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py

Large diffs are not rendered by default.

@@ -65,12 +65,25 @@ def convert_to_strs(self, processor):
video_token_regex: Optional[re.Pattern] = None
audio_token_regex: Optional[re.Pattern] = None

def __post_init__(self):
if self.image_token_regex is None and self.image_token is not None:
def compile_regex(self):
# TODO: move convert_to_strs to here, before compiling regex
if (
self.image_token_regex is None
and self.image_token is not None
and isinstance(self.image_token, str)
):
self.image_token_regex = re.compile(re.escape(self.image_token))
if self.video_token_regex is None and self.video_token is not None:
if (
self.video_token_regex is None
and self.video_token is not None
and isinstance(self.video_token, str)
):
self.video_token_regex = re.compile(re.escape(self.video_token))
if self.audio_token_regex is None and self.audio_token is not None:
if (
self.audio_token_regex is None
and self.audio_token is not None
and isinstance(self.audio_token, str)
):
self.audio_token_regex = re.compile(re.escape(self.audio_token))

def collect(self) -> re.Pattern:
@@ -216,34 +229,37 @@ def submit_data_loading_tasks(
# Submit all tasks
futures = []
task_info = []
image_index, audio_index = 0, 0
image_iter, audio_iter = None, None
if isinstance(image_data, list):
image_iter = iter(image_data)
if isinstance(audio_data, list):
audio_iter = iter(audio_data)

for text_part in text_parts:
if (
multimodal_tokens.image_token_regex
and multimodal_tokens.image_token_regex.match(text_part)
):
data = image_data[image_index]
assert image_iter
data = next(image_iter)
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
frame_count_limit = max(1, int(estimated_frames * scaling_factor))
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
is_video,
False,
frame_count_limit,
None,
discard_alpha_channel,
)
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
task_info.append((Modality.IMAGE, data, None))
Comment on lines 246 to +256 (severity: high)
The logic for frame_count_limit based on estimated_frames_list and scaling_factor has been removed, and _load_single_item is now called with None for the frame limit. This seems to disable frame limiting for videos in the base processor, which could be a significant breaking change affecting all models. If this is intentional, it should be documented. Otherwise, it might be a bug.
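If the limit is meant to stay, a small sketch of how the per-item frame limit could be kept alongside the new iterator-based loop (a hypothetical helper, assuming estimated_frames_list remains aligned with image_data):

```python
from typing import Iterable, Iterator, Optional, Tuple

def iter_with_frame_limits(
    image_data: Iterable,
    estimated_frames_list: Iterable[int],
    scaling_factor: float,
) -> Iterator[Tuple[object, Optional[int]]]:
    # Pair each image/video item with the frame limit the removed code computed,
    # so the caller can keep using next() without tracking an index by hand.
    for data, estimated_frames in zip(image_data, estimated_frames_list):
        yield data, max(1, int(estimated_frames * scaling_factor))
```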

elif (
multimodal_tokens.audio_token_regex
and multimodal_tokens.audio_token_regex.match(text_part)
):
data = audio_data[audio_index]
assert audio_iter
data = next(audio_iter)
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
Expand All @@ -255,7 +271,6 @@ def submit_data_loading_tasks(
)
)
task_info.append((Modality.AUDIO, data, None))
audio_index += 1

return futures, task_info

@@ -284,6 +299,8 @@ def load_mm_data(
image_data = []

multimodal_tokens.convert_to_strs(self._processor)
# TODO: remove this
multimodal_tokens.compile_regex()
Comment on lines +302 to +303 (severity: medium)
There's a TODO to remove the explicit call to compile_regex(). This suggests the current implementation is temporary. It would be cleaner to handle regex compilation within the MultimodalSpecialTokens class, for example in its __post_init__ or another initialization method, to avoid these explicit calls in the processor.
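A sketch of that suggestion, assuming MultimodalSpecialTokens is (or becomes) a dataclass so compilation happens in __post_init__ instead of through an explicit compile_regex() call (fields trimmed to the image token for brevity):

```python
import re
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class MultimodalSpecialTokensSketch:
    image_token: Optional[Union[str, int]] = None
    image_token_regex: Optional[re.Pattern] = None

    def __post_init__(self):
        # Compile at construction time so processors never have to call
        # compile_regex() explicitly before collect().
        if self.image_token_regex is None and isinstance(self.image_token, str):
            self.image_token_regex = re.compile(re.escape(self.image_token))
```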

multimodal_tokens_pattern = multimodal_tokens.collect()

if isinstance(prompt, list) and return_text:
@@ -361,6 +378,8 @@ def get_mm_items_offset(
mm_token_id = 3
return result = [(2,4),(6,7)]
"""
assert isinstance(mm_token_id, int), type(mm_token_id)
assert isinstance(input_ids, torch.Tensor), type(input_ids)
mask = input_ids == mm_token_id

start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
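For intuition, a self-contained run of this offset computation on the docstring's example (token id 3 at positions 2-4 and 6-7); the end_positions line mirrors start_positions and is included only to complete the illustration:

```python
import torch

input_ids = torch.tensor([0, 1, 3, 3, 3, 5, 3, 3])  # id 3 at positions 2..4 and 6..7
mm_token_id = 3

mask = input_ids == mm_token_id
# A run starts where the mask is True but the previous position is not.
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
# A run ends where the mask is True but the next position is not.
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]

print(list(zip(start_positions.tolist(), end_positions.tolist())))  # [(2, 4), (6, 7)]
```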
130 changes: 100 additions & 30 deletions python/sglang/srt/managers/multimodal_processors/qwen_vl.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import math
import re
@@ -14,28 +16,46 @@
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_5_omni import Qwen2_5OmniModel
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration


# Compatible with Qwen2VL and Qwen2_5VL
# Compatible with Qwen2VL, Qwen2_5VL and Qwen2_5_o
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
models = [
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5OmniModel,
]

def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# The single, pre-expanded image token.
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.IM_TOKEN_ID = hf_config.image_token_id
self.VIDEO_TOKEN_ID = hf_config.video_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
if self.arch == Qwen2_5OmniModel.__name__:
self.image_token_id = hf_config.thinker_config.image_token_index
self.image_start_id = hf_config.thinker_config.vision_start_token_id
self.image_end_id = hf_config.thinker_config.vision_end_token_id
self.audio_token_id = hf_config.thinker_config.audio_token_index
self.audio_start_id = hf_config.thinker_config.audio_start_token_id
self.audio_end_id = hf_config.thinker_config.audio_end_token_id
self.video_token_id = hf_config.thinker_config.video_token_index
# TODO: precomputed features might not need pre-processing anymore, try removing this
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_bos\|>(?:<\|IMAGE\|>)+<\|vision_eos\|>"
)
self.image_token = "<|vision_bos|><|IMAGE|><|vision_eo|>"
Comment (severity: critical)
There is a typo in the image_token. It's set to "<|vision_bos|><|IMAGE|><|vision_eo|>", but it should likely be "<|vision_bos|><|IMAGE|><|vision_eos|>" to match the conversation template. This will cause issues with tokenization and prompt formatting.

Suggested change:
- self.image_token = "<|vision_bos|><|IMAGE|><|vision_eo|>"
+ self.image_token = "<|vision_bos|><|IMAGE|><|vision_eos|>"

else:
self.image_token_id = hf_config.image_token_id
self.image_start_id = hf_config.vision_start_token_id
self.image_end_id = hf_config.vision_end_token_id
self.video_token_id = hf_config.video_token_id
# The regex that matches expanded image tokens.
self.IMAGE_TOKEN_REGEX = re.compile(
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
)
self.image_token = "<|vision_start|><|image_pad|><|vision_end|>"

self.NUM_TOKEN_PER_FRAME = 770
self.IMAGE_FACTOR = 28
self.MIN_PIXELS = 4 * 28 * 28
@@ -57,9 +77,12 @@ async def process_mm_data_async(
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
audio_data=request_obj.audio_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token=self.image_token,
image_token_regex=self.IMAGE_TOKEN_REGEX,
audio_token=getattr(self, "audio_token_id", None),
video_token=getattr(self, "video_token_id", None),
),
max_req_input_len=max_req_input_len,
)
@@ -130,6 +153,18 @@ async def resize_image_async(image):
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)

ret = self.process_mm_data(
input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
audio=base_output.audios,
)

input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=self.image_token_id
)

image_grid_thw = None
video_grid_thw = None # TODO

combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
Comment on lines +156 to 170 (severity: medium)
The method self.process_mm_data is called, and its return value ret is used later. However, self.process_and_combine_mm_data is also called, which internally calls self.process_mm_data again. This results in redundant processing and is inefficient. The logic should be refactored to avoid the duplicate call.

@@ -141,28 +176,63 @@ async def resize_image_async(image):
video_grid_thw = None # TODO
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)

mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.IM_TOKEN_ID,
video_token_id=self.VIDEO_TOKEN_ID,
vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=combined_mm_item.image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
)
if "input_features" in ret and ret["input_features"] is not None:
audio_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=getattr(self, "audio_token_id", None),
)
item = MultimodalDataItem(
audio_features=ret["input_features"],
feature_attention_mask=ret["feature_attention_mask"],
attention_mask=ret["attention_mask"],
# TODO: unify feature and offsets across modalities
audio_offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]

if self.hf_config.model_type == "qwen2_5_omni":
feature_attention_mask = ret.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
else:
audio_feature_lengths = None
mrope_positions, mrope_position_delta = (
MRotaryEmbedding.get_rope_index_omni(
input_ids=input_ids.unsqueeze(0),
config=self.hf_config.thinker_config,
image_grid_thw=ret.get("image_grid_thw", None),
video_grid_thw=ret.get("video_grid_thw", None),
audio_seqlens=audio_feature_lengths,
second_per_grids=ret.get("second_per_grids", None),
)
)
else:
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.IM_TOKEN_ID,
video_token_id=self.VIDEO_TOKEN_ID,
vision_start_token_id=self.image_start_id,
model_type=self.hf_config.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=input_ids.unsqueeze(0),
image_grid_thw=combined_mm_item.image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
)
mrope_positions = mrope_positions.squeeze(1)

return {
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_start_id": self.image_start_id,
"im_end_id": self.image_end_id,
"im_token_id": self.IM_TOKEN_ID,
Comment (severity: critical)
The returned dictionary uses self.IM_TOKEN_ID, but this attribute is not set for the Qwen2_5OmniModel architecture in the __init__ method. This will raise an AttributeError. You should use self.image_token_id instead, which is correctly initialized.

Suggested change:
- "im_token_id": self.IM_TOKEN_ID,
+ "im_token_id": self.image_token_id,

"audio_start_id": getattr(self, "audio_start_id", None),
"audio_end_id": getattr(self, "audio_end_id", None),
"audio_token_id": getattr(self, "audio_token_id", None),
"video_token_id": self.VIDEO_TOKEN_ID,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -214,6 +214,9 @@ class MultimodalDataItem:
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None

attention_mask: Optional[torch.Tensor] = None
feature_attention_mask: Optional[torch.Tensor] = None

precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None

@staticmethod