diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6ddd2484ae1..a11e3a7cc50 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -580,6 +580,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "MllamaForConditionalGeneration", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5OmniModel", "KimiVLForConditionalGeneration", "InternVLChatModel", "Phi4MMForCausalLM", diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 661f47700e0..7b0c20cc024 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -573,9 +573,9 @@ def generate_chat_conv( num_image_url += 1 conv.modalities.append(content.modalities) image_token = ( - conv.image_token + "\n" - if conv.name != "qwen2-vl" - else conv.image_token + conv.image_token + if "qwen2" in conv.name + else conv.image_token + "\n" ) add_token_as_needed: bool = ( conv.name in _MODELS_REQUIRING_MODALITY_SUPPLEMENT @@ -795,6 +795,22 @@ def generate_chat_conv( ) ) +# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example +register_conv_template( + Conversation( + name="qwen2-5-o", + system_message="You are a helpful assistant.", + system_template="<|im_start|>system\n{system_message}", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="<|im_end|>\n", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + stop_str=["<|im_end|>"], + image_token="<|vision_bos|><|IMAGE|><|vision_eos|>", + # video_token="<|vision_bos|><|VIDEO|><|vision_eos|>", + audio_token="<|audio_bos|><|AUDIO|><|audio_eos|>", + ) +) + register_conv_template( Conversation( name="deepseek-vl2", @@ -955,6 +971,8 @@ def match_qwen_chat_ml(model_path: str): return "gme-qwen2-vl" if re.search(r"qwen.*vl", model_path, re.IGNORECASE): return "qwen2-vl" + if re.search(r"qwen2.5.*omni.*", model_path, re.IGNORECASE): + return "qwen2-5-o" if re.search( r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2", model_path, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index bd145a4b030..6fd394da01a 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1,14 +1,19 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py """Rotary Positional Embeddings.""" +import logging import math from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn +from transformers import PretrainedConfig + +logger = logging.getLogger(__name__) from sglang.srt.custom_op import CustomOp from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu +from sglang.utils import logger _is_cuda = is_cuda() _is_hip = is_hip() @@ -915,6 +920,505 @@ def forward( key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key + # https://github.com/huggingface/transformers/blob/397a5ede33863d6f7137c771a68d40036cac0396/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py#L271 + @staticmethod + def get_rope_index_omni( + input_ids: Optional[torch.Tensor], + config: PretrainedConfig, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_audio_in_video: bool = False, + audio_seqlens: Optional[torch.LongTensor] = 
None, + second_per_grids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + image_token_id = config.image_token_index + video_token_id = config.video_token_index + audio_token_id = config.audio_token_index + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + seconds_per_chunk = config.seconds_per_chunk + spatial_merge_size = config.vision_config.spatial_merge_size + try: + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + attention_mask = torch.ones_like(total_input_ids).to("cuda") + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ).to("cuda") + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = ( + image_nums, + video_nums, + audio_nums, + ) + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + audio_len = ( + (audio_seqlens[audio_idx] - 1) // 2 + 1 - 2 + ) // 2 + 1 + llm_pos_ids = ( + torch.arange(audio_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + 
remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t).to("cuda") + * 1 + * position_id_per_seconds + ).long() + llm_pos_ids = MRotaryEmbedding.get_llm_pos_ids_for_vision( + st_idx, + image_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + ) + image_len = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ).to("cuda") + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t).to("cuda") + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ).long() + llm_pos_ids = MRotaryEmbedding.get_llm_pos_ids_for_vision( + st_idx, + video_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + llm_pos_ids_list.append( + torch.arange(bos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + audio_len = 
( + (audio_seqlens[audio_idx] - 1) // 2 + 1 - 2 + ) // 2 + 1 + audio_llm_pos_ids = ( + torch.arange(audio_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + ( + torch.arange(grid_t).to("cuda") + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ) + .long() + .to("cuda") + ) + video_llm_pos_ids = ( + MRotaryEmbedding.get_llm_pos_ids_for_vision( + st_idx, + video_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + ) + ) + + t_ntoken_per_chunk = int( + position_id_per_seconds * seconds_per_chunk + ) + video_chunk_indexes = MRotaryEmbedding.get_chunked_index( + video_llm_pos_ids, t_ntoken_per_chunk, st_idx + ) + audio_chunk_indexes = MRotaryEmbedding.get_chunked_index( + audio_llm_pos_ids, t_ntoken_per_chunk, st_idx + ) + sub_len = 0 + for j in range( + max(len(video_chunk_indexes), len(audio_chunk_indexes)) + ): + video_chunk_index = ( + video_chunk_indexes[j] + if j < len(video_chunk_indexes) + else None + ) + audio_chunk_index = ( + audio_chunk_indexes[j] + if j < len(audio_chunk_indexes) + else None + ) + if video_chunk_index is not None: + sub_len += ( + video_chunk_index[1] - video_chunk_index[0] + ) + + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, + video_chunk_index[0] : video_chunk_index[1], + ] + ) + if audio_chunk_index is not None: + sub_len += ( + audio_chunk_index[1] - audio_chunk_index[0] + ) + + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, + audio_chunk_index[0] : audio_chunk_index[1], + ] + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + llm_pos_ids_list.append( + torch.arange(eos_len) + .view(1, -1) + .expand(3, -1) + .to("cuda") + + st_idx + ) + + st += ( + text_len + + bos_len * 2 + + audio_len + + video_len + + eos_len * 2 + ) + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1).to("cuda") + + st_idx + ) + + llm_positions = ( + torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1).to("cuda") + ) + + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(input_ids) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + assert input_ids is not None, input_ids + # position_ids = attention_mask.long().cumsum(-1) - 1 + # position_ids.masked_fill_(attention_mask == 0, 1) + s = input_ids.shape[1] + position_ids = torch.arange(s) + position_ids = ( + position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - s + + return position_ids, mrope_position_deltas + except Exception as e: + logger.info(f"Please consider disabling chunked_prefill: {e}") + raise + # Copied from 
https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439 @staticmethod def get_rope_index( @@ -1091,6 +1595,78 @@ def get_next_input_positions( ] ) + @staticmethod + def get_llm_pos_ids_for_vision( + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: List[int], + grid_hs: List[int], + grid_ws: List[int], + ): + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = ( + torch.arange(llm_grid_h, device=llm_grid_h.device) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w, device=llm_grid_w.device) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index = ( + torch.Tensor(t_index) + .to(llm_grid_h.device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + .long() + ) + _llm_pos_ids = torch.stack([t_index, h_index, w_index]).to("cuda") + llm_pos_ids_list.append(_llm_pos_ids + start_idx) # + 1 ) # 12.09 by malinhan + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + @staticmethod + def _pad_to_list_of_tensors_1d(tensor_list, padding_value=0, padding_side="left"): + lengths = [len(tensor) for tensor in tensor_list] + max_length = max(lengths) + pad_len = [max_length - leng for leng in lengths] + for idx in range(len(tensor_list)): + if pad_len[idx] != 0: + if padding_side == "left": + tensor_list[idx] = torch.cat( + [ + torch.full( + size=[pad_len[idx]], + fill_value=padding_value, + dtype=tensor_list[idx].dtype, + device=tensor_list[idx].device, + ), + tensor_list[idx], + ], + dim=0, + ) + else: + tensor_list[idx] = torch.cat( + [ + tensor_list[idx], + torch.full( + size=[pad_len[idx]], + fill_value=padding_value, + dtype=tensor_list[idx].dtype, + device=tensor_list[idx].device, + ), + ], + dim=0, + ) + return tensor_list + _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index 618f66a2fd3..97b5df9e98d 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -65,12 +65,25 @@ def convert_to_strs(self, processor): video_token_regex: Optional[re.Pattern] = None audio_token_regex: Optional[re.Pattern] = None - def __post_init__(self): - if self.image_token_regex is None and self.image_token is not None: + def compile_regex(self): + # TODO: move convert_to_strs to here, before compiling regex + if ( + self.image_token_regex is None + and self.image_token is not None + and isinstance(self.image_token, str) + ): self.image_token_regex = re.compile(re.escape(self.image_token)) - if self.video_token_regex is None and self.video_token is not None: + if ( + self.video_token_regex is None + and self.video_token is not None + and isinstance(self.video_token, str) + ): self.video_token_regex = re.compile(re.escape(self.video_token)) - if self.audio_token_regex is None and self.audio_token is not None: + if ( + self.audio_token_regex is None + and self.audio_token is not None + and isinstance(self.audio_token, str) + ): self.audio_token_regex = re.compile(re.escape(self.audio_token)) def collect(self) -> re.Pattern: @@ -216,34 +229,37 @@ def submit_data_loading_tasks( # Submit all tasks futures = [] task_info = [] - image_index, 
audio_index = 0, 0 + image_iter, audio_iter = None, None + if isinstance(image_data, list): + image_iter = iter(image_data) + if isinstance(audio_data, list): + audio_iter = iter(audio_data) for text_part in text_parts: if ( multimodal_tokens.image_token_regex and multimodal_tokens.image_token_regex.match(text_part) ): - data = image_data[image_index] + assert image_iter + data = next(image_iter) is_video = isinstance(data, str) and data.startswith("video:") - estimated_frames = estimated_frames_list[image_index] - frame_count_limit = max(1, int(estimated_frames * scaling_factor)) futures.append( self.io_executor.submit( BaseMultimodalProcessor._load_single_item, data, is_video, False, - frame_count_limit, + None, discard_alpha_channel, ) ) - task_info.append((Modality.IMAGE, data, frame_count_limit)) - image_index += 1 + task_info.append((Modality.IMAGE, data, None)) elif ( multimodal_tokens.audio_token_regex and multimodal_tokens.audio_token_regex.match(text_part) ): - data = audio_data[audio_index] + assert audio_iter + data = next(audio_iter) futures.append( self.io_executor.submit( BaseMultimodalProcessor._load_single_item, @@ -255,7 +271,6 @@ def submit_data_loading_tasks( ) ) task_info.append((Modality.AUDIO, data, None)) - audio_index += 1 return futures, task_info @@ -284,6 +299,8 @@ def load_mm_data( image_data = [] multimodal_tokens.convert_to_strs(self._processor) + # TODO: remove this + multimodal_tokens.compile_regex() multimodal_tokens_pattern = multimodal_tokens.collect() if isinstance(prompt, list) and return_text: @@ -361,6 +378,8 @@ def get_mm_items_offset( mm_token_id = 3 return result = [(2,4),(6,7)] """ + assert isinstance(mm_token_id, int), type(mm_token_id) + assert isinstance(input_ids, torch.Tensor), type(input_ids) mask = input_ids == mm_token_id start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0] diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index d09b61b29d0..a5dfe16167f 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import math import re @@ -14,28 +16,46 @@ MultimodalSpecialTokens, ) from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem +from sglang.srt.models.qwen2_5_omni import Qwen2_5OmniModel from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration -# Compatible with Qwen2VL and Qwen2_5VL +# Compatible with Qwen2VL, Qwen2_5VL and Qwen2_5_o class Qwen2_5VLImageProcessor(SGLangBaseProcessor): - models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration] + models = [ + Qwen2VLForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, + Qwen2_5OmniModel, + ] def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) # The single, pre-expanded image token. - self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" - # The regex that matches expanded image tokens. 
-        self.IMAGE_TOKEN_REGEX = re.compile(
-            r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
-        )
-        self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
-        self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
-        self.IM_TOKEN_ID = hf_config.image_token_id
-        self.VIDEO_TOKEN_ID = hf_config.video_token_id
-        self.vision_start_token_id = hf_config.vision_start_token_id
-        self.vision_end_token_id = hf_config.vision_end_token_id
+        if self.arch == Qwen2_5OmniModel.__name__:
+            self.image_token_id = hf_config.thinker_config.image_token_index
+            self.image_start_id = hf_config.thinker_config.vision_start_token_id
+            self.image_end_id = hf_config.thinker_config.vision_end_token_id
+            self.audio_token_id = hf_config.thinker_config.audio_token_index
+            self.audio_start_id = hf_config.thinker_config.audio_start_token_id
+            self.audio_end_id = hf_config.thinker_config.audio_end_token_id
+            self.video_token_id = hf_config.thinker_config.video_token_index
+            # TODO: precomputed features might not need pre-processing anymore, try removing this
+            self.IMAGE_TOKEN_REGEX = re.compile(
+                r"<\|vision_bos\|>(?:<\|IMAGE\|>)+<\|vision_eos\|>"
+            )
+            self.image_token = "<|vision_bos|><|IMAGE|><|vision_eos|>"
+        else:
+            self.image_token_id = hf_config.image_token_id
+            self.image_start_id = hf_config.vision_start_token_id
+            self.image_end_id = hf_config.vision_end_token_id
+            self.video_token_id = hf_config.video_token_id
+            # The regex that matches expanded image tokens.
+            self.IMAGE_TOKEN_REGEX = re.compile(
+                r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
+            )
+            self.image_token = "<|vision_start|><|image_pad|><|vision_end|>"
+
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -57,9 +77,12 @@ async def process_mm_data_async(
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
+            audio_data=request_obj.audio_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
+                image_token=self.image_token,
                 image_token_regex=self.IMAGE_TOKEN_REGEX,
+                audio_token=getattr(self, "audio_token_id", None),
+                video_token=getattr(self, "video_token_id", None),
             ),
             max_req_input_len=max_req_input_len,
         )
@@ -130,6 +153,18 @@ async def resize_image_async(image):
         resize_tasks = [resize_image_async(image) for image in base_output.images]
         base_output.images = await asyncio.gather(*resize_tasks)

+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=None if images_are_preprocessed else base_output.images,
+            audio=base_output.audios,
+        )
+
+        input_ids = ret["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids, mm_token_id=self.image_token_id
+        )
+
+        image_grid_thw = None
         video_grid_thw = None  # TODO
         combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
@@ -141,28 +176,63 @@
         video_grid_thw = None  # TODO
         second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
-        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
-            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
-            image_token_id=self.IM_TOKEN_ID,
-            video_token_id=self.VIDEO_TOKEN_ID,
-            vision_start_token_id=self.vision_start_token_id,
-            model_type=self.hf_config.model_type,
-            tokens_per_second=getattr(
-                self.hf_config.vision_config, "tokens_per_second", None
-            ),
-            input_ids=input_ids.unsqueeze(0),
-            image_grid_thw=combined_mm_item.image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
-        )
+        if 
"input_features" in ret and ret["input_features"] is not None: + audio_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=getattr(self, "audio_token_id", None), + ) + item = MultimodalDataItem( + audio_features=ret["input_features"], + feature_attention_mask=ret["feature_attention_mask"], + attention_mask=ret["attention_mask"], + # TODO: unify feature and offsets across modalities + audio_offsets=audio_offsets, + modality=Modality.AUDIO, + ) + items += [item] + + if self.hf_config.model_type == "qwen2_5_omni": + feature_attention_mask = ret.get("feature_attention_mask", None) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + else: + audio_feature_lengths = None + mrope_positions, mrope_position_delta = ( + MRotaryEmbedding.get_rope_index_omni( + input_ids=input_ids.unsqueeze(0), + config=self.hf_config.thinker_config, + image_grid_thw=ret.get("image_grid_thw", None), + video_grid_thw=ret.get("video_grid_thw", None), + audio_seqlens=audio_feature_lengths, + second_per_grids=ret.get("second_per_grids", None), + ) + ) + else: + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( + spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, + image_token_id=self.IM_TOKEN_ID, + video_token_id=self.VIDEO_TOKEN_ID, + vision_start_token_id=self.image_start_id, + model_type=self.hf_config.model_type, + tokens_per_second=getattr( + self.hf_config.vision_config, "tokens_per_second", None + ), + input_ids=input_ids.unsqueeze(0), + image_grid_thw=combined_mm_item.image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) mrope_positions = mrope_positions.squeeze(1) return { "input_ids": input_ids.tolist(), "mm_items": [combined_mm_item], - "im_start_id": self.IM_START_TOKEN_ID, - "im_end_id": self.IM_END_TOKEN_ID, + "im_start_id": self.image_start_id, + "im_end_id": self.image_end_id, "im_token_id": self.IM_TOKEN_ID, + "audio_start_id": getattr(self, "audio_start_id", None), + "audio_end_id": getattr(self, "audio_end_id", None), + "audio_token_id": getattr(self, "audio_token_id", None), "video_token_id": self.VIDEO_TOKEN_ID, "mrope_positions": mrope_positions, "mrope_position_delta": mrope_position_delta, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e197073406a..c35f4cdc205 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -214,6 +214,9 @@ class MultimodalDataItem: audio_feature_lens: Optional[List[torch.Tensor]] = None audio_offsets: Optional[List[Tuple[int, int]]] = None + attention_mask: Optional[torch.Tensor] = None + feature_attention_mask: Optional[torch.Tensor] = None + precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None @staticmethod diff --git a/python/sglang/srt/models/qwen2_5_omni.py b/python/sglang/srt/models/qwen2_5_omni.py new file mode 100644 index 00000000000..ca9f4139504 --- /dev/null +++ b/python/sglang/srt/models/qwen2_5_omni.py @@ -0,0 +1,724 @@ +# Copied and adapted from: https://github.com/huggingface/transformers/blob/5efaed689114030ffaf51c02f6f82adcbfc72389/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +from typing import Iterable, List, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutputWithPast +from 
transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoder, +) +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionPatchEmbed, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLMLP, +) + +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2 import Qwen2Attention, Qwen2MLP +from sglang.srt.utils import add_prefix +from sglang.utils import logger + +############################ +# Start Thinker # +############################ + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2) + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class Qwen2_5OmniVisionBlock(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.norm1 = RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = RMSNorm(config.hidden_size, eps=1e-6) + self.attn = VisionAttention( + embed_dim=config.hidden_size, + num_heads=config.num_heads, + projection_size=config.hidden_size, + use_qkv_parallel=True, + qkv_backend="sdpa", + softmax_in_single_precision=True, + flatten_batch=True, + quant_config=quant_config, + proj_bias=True, + prefix=add_prefix("attn", prefix), + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward(self, hidden_states, cu_seqlens, position_embeddings) -> torch.Tensor: + seq_len, _ = hidden_states.size() + + normed_hs = self.norm1(hidden_states) + normed_hs = normed_hs.unsqueeze(0) + attn = self.attn( + normed_hs, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings + ) + hidden_states = hidden_states + attn + hidden_states = hidden_states.view(seq_len, -1) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5OmniPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2_5OmniVisionEncoder(nn.Module): + _no_split_modules = 
["Qwen2_5OmniVisionBlock"] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.blocks = nn.ModuleList( + [ + Qwen2_5OmniVisionBlock(config, quant_config=quant_config) + for _ in range(config.depth) + ] + ) + self.merger = Qwen2_5OmniPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) + 
cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward( + self, hidden_states: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Modification here + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings, + ) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + def get_dtype(self) -> torch.dtype: + return self.blocks[0].mlp.gate_proj.weight.dtype + + def get_device(self) -> torch.device: + return self.blocks[0].mlp.gate_proj.weight.device + + +class Qwen2_5OmniDecoderLayer(nn.Module): + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + ): + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + + self.rope_scaling = config.rope_scaling + hidden_size = config.hidden_size + head_num = config.num_attention_heads + kv_head_num = config.num_key_value_heads + head_dim = hidden_size // head_num + + self.q_size = head_num * head_dim + self.kv_size = kv_head_num * head_dim + + text_config = config + self.mlp = Qwen2MLP( + hidden_size=text_config.hidden_size, + intermediate_size=text_config.intermediate_size, + hidden_act=text_config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + rope_scaling = config.rope_scaling + rope_theta = getattr(config, "rope_theta", 1000000) + max_position_embeddings = getattr(config, "max_position_embeddings", 32768) + self.self_attn = Qwen2Attention( + hidden_size=hidden_size, + num_heads=head_num, + num_kv_heads=kv_head_num, + layer_id=layer_idx, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + + text_config = config + self.mlp = Qwen2MLP( + hidden_size=text_config.hidden_size, + intermediate_size=text_config.intermediate_size, + hidden_act=text_config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=position_ids, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2_5OmniThinkerModel(nn.Module): + _no_split_modules = ["Qwen2_5OmniDecoderLayer"] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + + self.layers = nn.ModuleList( + [ + Qwen2_5OmniDecoderLayer(config, layer_idx, quant_config=quant_config) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self): + return self.embed_tokens + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + + hidden_states = input_embeds + for layer_idx, decoder_layer in enumerate(self.layers): + layer_output = decoder_layer( + positions=positions, + position_ids=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = layer_output + + if hidden_states.dim() == 1: + hidden_states = hidden_states.unsqueeze(0) + + assert hidden_states.dim() == 2, hidden_states.shape + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class Qwen2_5OmniThinkerForConditionalGeneration(nn.Module): + _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.config = config + self.text_config = config.text_config + + self.audio_tower = Qwen2_5OmniAudioEncoder(config.audio_config) + + self.visual = Qwen2_5OmniVisionEncoder( + config.vision_config, quant_config=quant_config + ) + + self.vocab_size = config.text_config.vocab_size + self.model = Qwen2_5OmniThinkerModel( + config.text_config, quant_config=quant_config + ) + text_config = config.text_config + self.lm_head = ParallelLMHead( + self.vocab_size, + text_config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + + self.is_mrope_enabled = "mrope_section" in self.config.text_config.rope_scaling + + self.logits_processor = LogitsProcessor(config.text_config) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + # in qwen-vl, last dim is the same + pixel_values = torch.concat([item.pixel_values for item in items], dim=0).type( + self.visual.dtype + ) + + image_grid_thws = 
torch.concat([item.image_grid_thws for item in items], dim=0) + assert pixel_values.dim() == 2, pixel_values.dim() + assert image_grid_thws.dim() == 2, image_grid_thws.dim() + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws) + + return image_embeds + + def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: + input_features = ( + torch.cat([item.audio_features for item in items]) + .type(self.audio_tower.dtype) + .to(next(self.audio_tower.parameters()).device) + ) + feature_attention_mask = torch.cat( + [item.feature_attention_mask for item in items] + ) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + audio_feature_lengths = None + + feature_attention_mask = torch.sum(feature_attention_mask, dim=1) + + audio_feat_lengths, audio_output_lengths = ( + self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths + if audio_feature_lengths is not None + else feature_attention_mask.sum(-1) + ) + ) + feature_lens = ( + audio_feature_lengths + if audio_feature_lengths is not None + else feature_attention_mask.sum(-1) + ) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + return audio_features + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + + if self.is_mrope_enabled: + positions = forward_batch.mrope_positions + hs = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.model, + image_data_embedding_func=self.get_image_feature, + audio_data_embedding_func=self.get_audio_feature, + positions=positions, + ) + return self.logits_processor(input_ids, hs, self.lm_head, forward_batch) + + +############################ +# Start Qwen2.5Omni # +############################ + + +class Qwen2_5OmniModel(nn.Module): + _no_split_modules = [ + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavModel", + ] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.thinker = Qwen2_5OmniThinkerForConditionalGeneration( + config.thinker_config, + quant_config=quant_config, + prefix=add_prefix("thinker", prefix), + ) + self.has_talker = config.enable_audio_output + self.speaker_map = {} + + config.enable_audio_output = False + logger.info(f"Talker is not yet supported.") + if config.enable_audio_output: + self.enable_talker() + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + # Get all special token IDs + media_token_ids = [mm_inputs.im_token_id, mm_inputs.audio_token_id] + pattern = MultiModalityDataPaddingPatternMultimodalTokens(media_token_ids) + return pattern.pad_input_tokens(input_ids, mm_inputs) + + def load_speakers(self, path): + for key, value in torch.load(path).items(): + self.speaker_map[key] = value + logger.info("Speaker {} loaded".format(list(self.speaker_map.keys()))) + + @classmethod + def can_generate(cls) -> bool: + return True + + @torch.no_grad() + def forward( + self, + input_ids: torch.tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ): + # 1. 
Generate from thinker module + thinker_result = self.thinker.forward( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + ) + + return thinker_result + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + # VisionAttention + (".qkv_proj.", ".q.", "q"), + (".qkv_proj.", ".k.", "k"), + (".qkv_proj.", ".v.", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + if "audio_tower" in name or "talker" in name: + continue + + if "visual" in name: + # mlp + if "gate_proj" in name or "up_proj" in name: + continue + ... + + name = name.replace(weight_name, param_name) + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + print(f"skipping {name}") + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + + if "talker" in name or "token2wav" in name: + continue + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warn( + f"Some weights are not initialized from checkpoints: {sorted(unloaded_params)}" + ) + + +EntryClass = Qwen2_5OmniModel diff --git a/test/srt/test_omni_openai_server.py b/test/srt/test_omni_openai_server.py new file mode 100644 index 00000000000..044aecbea7e --- /dev/null +++ b/test/srt/test_omni_openai_server.py @@ -0,0 +1,147 @@ +""" +Usage: +python3 -m unittest test_omni_openai_server +""" + +import unittest + +from test_vision_openai_server_common import * + + +# Omni Models +class TestOpenAIOmniServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "openbmb/MiniCPM-o-2_6" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + "0.7", + ], + ) + cls.base_url += "/v1" + + def prepare_audio_messages(self, prompt, audio_file_name): + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": f"{audio_file_name}"}, + }, + { + "type": "text", + "text": prompt, + }, + ], + } + ] + + return messages + + def get_audio_response(self, url: str, prompt, category): + audio_file_path = self.get_or_download_file(url) + client = openai.Client(api_key="sk-123456", base_url=self.base_url) + + messages = self.prepare_audio_messages(prompt, audio_file_path) + + response = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=128, + stream=False, + ) + + audio_response = response.choices[0].message.content + + print("-" * 30) + print(f"audio {category} 
response:\n{audio_response}") + print("-" * 30) + + audio_response = audio_response.lower() + + self.assertIsNotNone(audio_response) + self.assertGreater(len(audio_response), 0) + + return audio_response + + def verify_speech_recognition_response(self, text): + text = text.lower() + assert "thank you" in text + assert "it's a privilege to be here" in text + assert "leader" in text + assert "science" in text + assert "art" in text + + def _test_audio_speech_completion(self): + # a fragment of Trump's speech + audio_response = self.get_audio_response( + AUDIO_TRUMP_SPEECH_URL, + # "I have an audio sample. Please repeat the person's words", + "Repeat exactly what does the person say in the audio. Be exact", + category="speech", + ) + self.verify_speech_recognition_response(audio_response) + + def _test_audio_ambient_completion(self): + # bird song + audio_response = self.get_audio_response( + AUDIO_BIRD_SONG_URL, + "Please listen to the audio snippet carefully and transcribe the content.", + "ambient", + ) + assert "bird" in audio_response + + def test_audio_chat_completion(self): + self._test_audio_speech_completion() + self._test_audio_ambient_completion() + + def test_mixed_modality_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": IMAGE_MAN_IRONING_URL}, + }, + { + "type": "audio_url", + "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL}, + }, + { + "type": "text", + "text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact", + }, + ], + }, + ] + response = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=128, + stream=False, + ) + + text = response.choices[0].message.content + + print("-" * 30) + print(f"Mixed modality response:\n{text}") + print("-" * 30) + + self.verify_single_image_response(response=response) + self.verify_speech_recognition_response(text=text) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index a4a2e770dbf..c8bb3b71326 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -162,28 +162,5 @@ def setUpClass(cls): cls.base_url += "/v1" -class TestMinicpmoServer(TestOpenAIVisionServer): - @classmethod - def setUpClass(cls): - cls.model = "openbmb/MiniCPM-o-2_6" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--mem-fraction-static", - "0.7", - ], - ) - cls.base_url += "/v1" - - def test_audio_chat_completion(self): - self._test_audio_speech_completion() - self._test_audio_ambient_completion() - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 3687d9381d3..cd941372e91 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -1,5 +1,4 @@ import base64 -import copy import io import json import os @@ -48,6 +47,36 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) + def verify_single_image_response(self, response): + assert response.choices[0].message.role == "assistant" + 
+        text = response.choices[0].message.content
+        assert isinstance(text, str)
+
+        print("-" * 30)
+        print(f"Single image response:\n{text}")
+        print("-" * 30)
+
+        # `driver` is for gemma-3-it
+        assert (
+            "man" in text or "person" in text or "driver" in text
+        ), f"text: {text}, should contain man, person or driver"
+        assert (
+            "cab" in text
+            or "taxi" in text
+            or "SUV" in text
+            or "vehicle" in text
+            or "car" in text
+        ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
+        # MiniCPMO fails to recognize `iron`, but `hanging`
+        assert (
+            "iron" in text or "hang" in text or "cloth" in text or "holding" in text
+        ), f"text: {text}, should contain iron, hang, cloth or holding"
+        assert response.id
+        assert response.created
+        assert response.usage.prompt_tokens > 0
+        assert response.usage.completion_tokens > 0
+        assert response.usage.total_tokens > 0
+
     def get_request_kwargs(self):
         return {}

@@ -66,38 +95,15 @@ def test_single_image_chat_completion(self):
                        },
                        {
                            "type": "text",
-                            "text": "Describe this image in a very short sentence.",
+                            "text": "Describe this image in a single sentence of more than 20 words.",
                        },
                    ],
                },
            ],
            temperature=0,
-            **(self.get_request_kwargs()),
        )
-        assert response.choices[0].message.role == "assistant"
-        text = response.choices[0].message.content
-        assert isinstance(text, str)
-        # `driver` is for gemma-3-it
-        assert (
-            "man" in text or "person" or "driver" in text
-        ), f"text: {text}, should contain man, person or driver"
-        assert (
-            "cab" in text
-            or "taxi" in text
-            or "SUV" in text
-            or "vehicle" in text
-            or "car" in text
-        ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
-        # MiniCPMO fails to recognize `iron`, but `hanging`
-        assert (
-            "iron" in text or "hang" in text or "cloth" in text or "holding" in text
-        ), f"text: {text}, should contain iron, hang, cloth or holding"
-        assert response.id
-        assert response.created
-        assert response.usage.prompt_tokens > 0
-        assert response.usage.completion_tokens > 0
-        assert response.usage.total_tokens > 0
+        self.verify_single_image_response(response=response)

    def test_multi_turn_chat_completion(self):
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -201,7 +207,6 @@ def test_multi_images_chat_completion(self):

    def prepare_video_messages(self, video_path):
        # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
-        # the size of the video embeds differs from the `modality` argument when preprocessed

        # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
        # The following import order will cause Segmentation fault.
        # from transformers import AutoTokenizer
        from decord import VideoReader, cpu

-        max_frames_num = 20
+        max_frames_num = 5
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frame_num = len(vr)
        uniform_sampled_frames = np.linspace(