diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 8cbbded9d7e..42c28bc3f17 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -853,6 +853,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal "Qwen2_5_VLForConditionalGeneration", "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration", + "Qwen3OmniMoeForConditionalGeneration", "KimiVLForConditionalGeneration", "InternVLChatModel", "InternS1ForConditionalGeneration", diff --git a/python/sglang/srt/configs/qwen3_omni.py b/python/sglang/srt/configs/qwen3_omni.py new file mode 100644 index 00000000000..d42e98a9a07 --- /dev/null +++ b/python/sglang/srt/configs/qwen3_omni.py @@ -0,0 +1,613 @@ +from transformers import PretrainedConfig +from transformers.configuration_utils import layer_type_validation +from transformers.modeling_rope_utils import rope_config_validation + +from sglang.utils import logger + + +class Qwen3OmniMoeAudioEncoderConfig(PretrainedConfig): + model_type = "qwen3_omni_moe_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + d_model=1280, + dropout=0, + attention_dropout=0, + activation_function="gelu", + activation_dropout=0, + scale_embedding=False, + initializer_range=0.02, + max_source_positions=1500, + n_window=100, + output_dim=3584, + n_window_infer=400, + conv_chunksize=500, + downsample_hidden_size=480, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.num_hidden_layers = encoder_layers + self.initializer_range = initializer_range + self.scale_embedding = ( + scale_embedding # scale factor will be sqrt(d_model) if True + ) + self.max_source_positions = max_source_positions + self.n_window = n_window + self.output_dim = output_dim + self.n_window_infer = n_window_infer + self.conv_chunksize = conv_chunksize + self.downsample_hidden_size = downsample_hidden_size + + +class Qwen3OmniMoeVisionEncoderConfig(PretrainedConfig): + model_type = "qwen3_omni_moe_vision_encoder" + base_config_key = "vision_config" + + def __init__( + self, + depth=27, + hidden_size=1152, + hidden_act="gelu_pytorch_tanh", + intermediate_size=4304, + num_heads=16, + in_channels=3, + patch_size=16, + spatial_merge_size=2, + temporal_patch_size=2, + out_hidden_size=3584, + num_position_embeddings=2304, + deepstack_visual_indexes=[8, 16, 24], + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.out_hidden_size = out_hidden_size + self.num_position_embeddings = num_position_embeddings + self.initializer_range = initializer_range + self.deepstack_visual_indexes = deepstack_visual_indexes + + +class Qwen3OmniMoeTextConfig(PretrainedConfig): + model_type = "qwen3_omni_moe_text" + keys_to_ignore_at_inference = 
["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3OmniMoeText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=3584, + hidden_size=2048, + intermediate_size=18944, + num_hidden_layers=28, + num_attention_heads=28, + num_key_value_heads=4, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + rope_scaling=None, + attention_bias=False, + sliding_window=None, + attention_dropout=0, + decoder_sparse_step=1, + moe_intermediate_size=768, + num_experts_per_tok=8, + num_experts=128, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + +class Qwen3OmniMoeThinkerConfig(PretrainedConfig): + model_type = "qwen3_omni_moe_thinker" + attribute_map = { + "image_token_id": "image_token_index", + "video_token_id": "video_token_index", + "audio_token_id": "audio_token_index", + } + sub_configs = { + "audio_config": Qwen3OmniMoeAudioEncoderConfig, + "vision_config": Qwen3OmniMoeVisionEncoderConfig, + "text_config": Qwen3OmniMoeTextConfig, + } + + def __init__( + self, + audio_config=None, + vision_config=None, + text_config=None, + audio_token_id=151646, + image_token_id=151655, + video_token_id=151656, + position_id_per_seconds=25, + audio_start_token_id=151647, + user_token_id=872, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.user_token_id = user_token_id + self.position_id_per_seconds = position_id_per_seconds + self.audio_start_token_id = audio_start_token_id + self.initializer_range = initializer_range + + if isinstance(vision_config, dict): + vision_config = Qwen3OmniMoeVisionEncoderConfig(**vision_config) + elif vision_config is None: + vision_config = Qwen3OmniMoeVisionEncoderConfig() + self.vision_config = vision_config + + if isinstance(audio_config, dict): + audio_config = Qwen3OmniMoeAudioEncoderConfig(**audio_config) + elif audio_config is None: + audio_config = Qwen3OmniMoeAudioEncoderConfig() + self.audio_config = audio_config + + if isinstance(text_config, dict): + text_config = Qwen3OmniMoeTextConfig(**text_config) + elif text_config is None: + text_config = Qwen3OmniMoeTextConfig() + self.text_config = text_config + self.audio_token_id = audio_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + + +class Qwen3OmniMoeTalkerCodePredictorConfig(PretrainedConfig): + + model_type = "qwen3_omni_moe_talker_code_predictor" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerCodePredictor` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=2048, + hidden_size=1024, + intermediate_size=3072, + num_hidden_layers=5, + num_attention_heads=16, + num_key_value_heads=8, + head_dim=128, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=0.000001, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000, + rope_scaling=None, + attention_bias=False, + sliding_window=None, + layer_types=None, + attention_dropout=0, + num_code_groups=32, + **kwargs, + ): + super().__init__( + 
tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + ( + "sliding_attention" + if self.sliding_window is not None and i >= self.max_window_layers + else "full_attention" + ) + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types, self.num_hidden_layers) + self.num_code_groups = num_code_groups + + +class Qwen3OmniMoeTalkerTextConfig(PretrainedConfig): + + model_type = "qwen3_omni_moe_talker_text" + keys_to_ignore_at_inference = ["past_key_values"] + + # Default tensor parallel plan for base model `Qwen3OmniMoeTalkerText` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.experts.*.gate_proj": "colwise", + "layers.*.mlp.experts.*.up_proj": "colwise", + "layers.*.mlp.experts.*.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=3072, + hidden_size=1024, + intermediate_size=2048, + num_hidden_layers=20, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=0.000001, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000, + rope_scaling=None, + attention_bias=False, + sliding_window=None, + attention_dropout=0, + decoder_sparse_step=1, + moe_intermediate_size=384, + num_experts_per_tok=8, + num_experts=128, + norm_topk_prob=False, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=None, + **kwargs, + ): + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + 
self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, move it to 'rope_type'. + if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers + + +class Qwen3OmniMoeTalkerConfig(PretrainedConfig): + + sub_configs = { + "code_predictor_config": Qwen3OmniMoeTalkerCodePredictorConfig, + "text_config": Qwen3OmniMoeTalkerTextConfig, + } + + def __init__( + self, + code_predictor_config=None, + text_config=None, + num_code_groups=32, + thinker_hidden_size=2048, + codec_eos_token_id=4198, + accept_hidden_layer=18, + codec_nothink_id=4203, + codec_think_bos_id=4204, + codec_think_eos_id=4205, + codec_pad_id=4196, + codec_bos_id=4197, + audio_token_id=151646, + image_token_id=151655, + video_token_id=151656, + vision_start_token_id=151652, + position_id_per_seconds=25, + audio_start_token_id=151669, + speaker_id=None, + **kwargs, + ): + super().__init__(**kwargs) + if code_predictor_config is None: + code_predictor_config = {} + self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig() + logger.info( + "code_predictor_config is None. Initializing code_predictor_config model with default values" + ) + elif isinstance(code_predictor_config, Qwen3OmniMoeTalkerCodePredictorConfig): + self.code_predictor_config = code_predictor_config + else: + self.code_predictor_config = Qwen3OmniMoeTalkerCodePredictorConfig( + **code_predictor_config + ) + + if text_config is None: + text_config = {} + self.text_config = Qwen3OmniMoeTalkerTextConfig() + logger.info( + "talker text_config is None. 
Initializing talker text model with default values" + ) + elif isinstance(text_config, Qwen3OmniMoeTalkerTextConfig): + self.text_config = text_config + else: + self.text_config = Qwen3OmniMoeTalkerTextConfig(**text_config) + self.num_code_groups = num_code_groups + self.thinker_hidden_size = thinker_hidden_size + self.codec_eos_token_id = codec_eos_token_id + self.accept_hidden_layer = accept_hidden_layer + self.codec_nothink_id = codec_nothink_id + self.codec_think_bos_id = codec_think_bos_id + self.codec_think_eos_id = codec_think_eos_id + self.codec_pad_id = codec_pad_id + self.codec_bos_id = codec_bos_id + self.audio_token_id = audio_token_id + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.position_id_per_seconds = position_id_per_seconds + self.audio_start_token_id = audio_start_token_id + self.vision_start_token_id = vision_start_token_id + self.speaker_id = speaker_id + + +class Qwen3OmniMoeCode2WavConfig(PretrainedConfig): + + def __init__( + self, + codebook_size=2048, + hidden_size=1024, + max_position_embeddings=8000, + rope_theta=10000, + num_attention_heads=16, + num_key_value_heads=16, + attention_bias=False, + sliding_window=72, + intermediate_size=3072, + hidden_act="silu", + layer_scale_initial_scale=0.01, + rms_norm_eps=1e-5, + num_hidden_layers=8, + num_quantizers=16, + upsample_rates=(8, 5, 4, 3), + upsampling_ratios=(2, 2), + decoder_dim=1536, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.codebook_size = codebook_size + self.hidden_size = hidden_size + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.sliding_window = sliding_window + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.layer_scale_initial_scale = layer_scale_initial_scale + self.rms_norm_eps = rms_norm_eps + self.num_hidden_layers = num_hidden_layers + self.num_quantizers = num_quantizers + self.upsample_rates = upsample_rates + self.upsampling_ratios = upsampling_ratios + self.decoder_dim = decoder_dim + self.attention_dropout = attention_dropout + + @property + def layer_types(self): + """ + All layer in code2wav should be sliding attention + """ + return ["sliding_attention"] * self.num_hidden_layers + + +class Qwen3OmniMoeConfig(PretrainedConfig): + + model_type = "qwen3_omni_moe" + sub_configs = { + "thinker_config": Qwen3OmniMoeThinkerConfig, + "talker_config": Qwen3OmniMoeTalkerConfig, + "code2wav_config": Qwen3OmniMoeCode2WavConfig, + } + + def __init__( + self, + thinker_config=None, + talker_config=None, + code2wav_config=None, + enable_audio_output=True, + im_start_token_id=151644, + im_end_token_id=151645, + tts_pad_token_id=151671, + tts_bos_token_id=151672, + tts_eos_token_id=151673, + system_token_id=8948, + user_token_id=872, + assistant_token_id=77091, + **kwargs, + ): + super().__init__(**kwargs) + if thinker_config is None: + thinker_config = {} + logger.info( + "thinker_config is None. Initializing thinker model with default values" + ) + + if talker_config is None: + talker_config = {} + logger.info( + "talker_config is None. Initializing talker model with default values" + ) + + if code2wav_config is None: + code2wav_config = {} + logger.info( + "code2wav_config is None. 
Initializing code2wav model with default values" + ) + + self.thinker_config = Qwen3OmniMoeThinkerConfig(**thinker_config) + self.talker_config = Qwen3OmniMoeTalkerConfig(**talker_config) + self.code2wav_config = Qwen3OmniMoeCode2WavConfig(**code2wav_config) + self.enable_audio_output = enable_audio_output + self.im_start_token_id = im_start_token_id + self.im_end_token_id = im_end_token_id + self.tts_pad_token_id = tts_pad_token_id + self.tts_bos_token_id = tts_bos_token_id + self.tts_eos_token_id = tts_eos_token_id + self.system_token_id = system_token_id + self.user_token_id = user_token_id + self.assistant_token_id = assistant_token_id + + def get_text_config(self, decoder=False) -> "PretrainedConfig": + """ + Returns the config that is meant to be used with text IO. On most models, it is the original config instance + itself. On specific composite models, it is under a set of valid names. + + Args: + decoder (`Optional[bool]`, *optional*, defaults to `False`): + If set to `True`, then only search for decoder config names. + """ + # Overridden for deeply nested config like Qwen2-Omni. We don't have any omni model + # except for Qwen yet. This has to be generalized if more deeply nested configs are + # added. NOTE: currently method used only by vLLM + return self.thinker_config.get_text_config() diff --git a/python/sglang/srt/configs/qwen3_vl.py b/python/sglang/srt/configs/qwen3_vl.py index 4a995c856bc..a758d1f4e45 100644 --- a/python/sglang/srt/configs/qwen3_vl.py +++ b/python/sglang/srt/configs/qwen3_vl.py @@ -1,5 +1,3 @@ -from typing import Optional, Union - from transformers import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation @@ -576,11 +574,3 @@ def __init__( self.vision_start_token_id = vision_start_token_id self.vision_end_token_id = vision_end_token_id super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) - - -__all__ = [ - "Qwen3VLMoeConfig", - "Qwen3VLMoeVisionConfig", - "Qwen3VLConfig", - "Qwen3VLVisionConfig", -] diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 15b112539ff..e4ca62c8672 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1156,6 +1156,20 @@ def get_rope_index( second_per_grid_ts: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if model_type == "qwen3_omni_moe": + # For qwen3-omni + return MRotaryEmbedding.get_rope_index_qwen3_omni( + spatial_merge_size, + image_token_id, + video_token_id, + vision_start_token_id, + tokens_per_second, + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + **kwargs, + ) if ( model_type.startswith("qwen3_vl") or model_type.startswith("qwen3_vl_moe") ) and video_grid_thw is not None: @@ -1163,6 +1177,7 @@ def get_rope_index( video_grid_thw, video_grid_thw[:, 0], dim=0 ) video_grid_thw[:, 0] = 1 + mrope_position_deltas = [] if input_ids is not None and ( image_grid_thw is not None or video_grid_thw is not None @@ -1248,7 +1263,11 @@ def get_rope_index( time_tensor_long = time_tensor.long() t_index = time_tensor_long.flatten() - elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"): + elif model_type in ( + "qwen2_vl", + "qwen3_vl", + "qwen3_vl_moe", + ): t_index = ( torch.arange(llm_grid_t) .view(-1, 1) @@ -1256,7 +1275,7 @@ def get_rope_index( .flatten() ) else: - raise RuntimeError("Unimplemented") + raise RuntimeError(f"Unimplemented model type: {model_type}") h_index = ( torch.arange(llm_grid_h) 
.view(1, -1, 1) @@ -1306,6 +1325,304 @@ def get_rope_index( mrope_position_deltas = max_position_ids + 1 - s return position_ids, mrope_position_deltas + @staticmethod + def get_rope_index_qwen3_omni( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + tokens_per_second: Optional[int] = None, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # For qwen3-omni + audio_token_id = kwargs["audio_token_id"] + audio_start_token_id = kwargs["audio_start_token_id"] + position_id_per_seconds = kwargs["position_id_per_seconds"] + use_audio_in_video = kwargs.get("use_audio_in_video", False) + audio_seqlens = kwargs.get("audio_seqlens", None) + second_per_grids = second_per_grid_ts + + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=torch.float, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + for i, current_input_ids in enumerate(total_input_ids): + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere( + current_input_ids == vision_start_token_id + ).squeeze(1) + if vision_start_indices.numel() > 0: + vision_tokens = current_input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + audio_nums = torch.sum(current_input_ids == audio_start_token_id) + input_tokens = current_input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = ( + image_nums, + video_nums, + audio_nums, + ) + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + ed_vision_start = ( + input_tokens.index(vision_start_token_id, st) + if ( + ( + image_token_id in input_tokens + or video_token_id in input_tokens + ) + and (remain_videos > 0 or remain_images > 0) + ) + else len(input_tokens) + 1 + ) + ed_audio_start = ( + input_tokens.index(audio_start_token_id, st) + if (audio_token_id in input_tokens and remain_audios > 0) + else len(input_tokens) + 1 + ) + min_ed = min(ed_vision_start, ed_audio_start) + + text_len = min_ed - st + if text_len != 0: + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + st_idx += text_len + # Audio in Video + if ( + min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + ): + bos_len, eos_len = 2, 2 + else: + bos_len, eos_len = 1, 1 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + st_idx += bos_len + # Audio Only + if min_ed == ed_audio_start: + audio_len = MRotaryEmbedding._get_feat_extract_output_lengths( + audio_seqlens[audio_idx] + ) + llm_pos_ids = ( + torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + ) + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + audio_len + eos_len) + audio_idx += 1 + remain_audios -= 1 + + # Image Only + elif ( + min_ed == 
ed_vision_start + and current_input_ids[ed_vision_start + 1] == image_token_id + ): + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * 1 * position_id_per_seconds + ).float() + llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision( + st_idx, + image_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + input_ids.device, + ) + image_len = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + image_len + eos_len) + image_idx += 1 + remain_images -= 1 + + # Video Only + elif ( + min_ed == ed_vision_start + and current_input_ids[ed_vision_start + 1] == video_token_id + ): + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ).float() + llm_pos_ids = MRotaryEmbedding._get_llm_pos_ids_for_vision( + st_idx, + video_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + input_ids.device, + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + llm_pos_ids_list.append(llm_pos_ids) + + st += int(text_len + bos_len + video_len + eos_len) + video_idx += 1 + remain_videos -= 1 + + # Audio in Video + elif ( + min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + ): + audio_len = MRotaryEmbedding._get_feat_extract_output_lengths( + audio_seqlens[audio_idx] + ) + audio_llm_pos_ids = ( + torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + ) + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ).float() + video_llm_pos_ids = ( + MRotaryEmbedding._get_llm_pos_ids_for_vision( + st_idx, + video_idx, + spatial_merge_size, + t_index, + grid_hs, + grid_ws, + input_ids.device, + ) + ) + video_data_index, audio_data_index = 0, 0 + while ( + video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index < audio_llm_pos_ids.shape[-1] + ): + if ( + video_llm_pos_ids[0][video_data_index] + <= audio_llm_pos_ids[0][audio_data_index] + ): + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_data_index + 1 + ] + ) + video_data_index += 1 + else: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_data_index + 1 + ] + ) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_data_index : video_llm_pos_ids.shape[-1] + ] + ) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_data_index : audio_llm_pos_ids.shape[-1] + ] + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + + st += int(text_len + bos_len + audio_len + video_len + eos_len) + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx + ) + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) 
+ st_idx + ) + + llm_positions = torch.cat( + [item.float() for item in llm_pos_ids_list], dim=1 + ).reshape(3, -1) + + position_ids[..., i, :] = llm_positions.to(position_ids.device) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(current_input_ids) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + s = input_ids.shape[1] + position_ids = torch.arange(s) + position_ids = ( + position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - s + + return position_ids, mrope_position_deltas + # Adapted from https://github.com/vllm-project/vllm/blob/3779eb8c81449b924a23457fc77e45a0e6171178/vllm/model_executor/layers/rotary_embedding.py#L1120 @staticmethod def get_rope_index_glm4v( @@ -1504,6 +1821,44 @@ def get_rope_index_glm4v( return position_ids, mrope_position_deltas + # For qwen3-omni + @staticmethod + def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return output_lengths + + # For qwen3-omni + @staticmethod + def _get_llm_pos_ids_for_vision( + st_idx, vision_idx, spatial_merge_size, t_index, grid_hs, grid_ws, device + ): + grid_h = grid_hs[vision_idx] // spatial_merge_size + grid_w = grid_ws[vision_idx] // spatial_merge_size + + h_index = ( + torch.arange(grid_h, device=device) + .view(1, -1, 1) + .expand(len(t_index), -1, grid_w) + .flatten() + ) + w_index = ( + torch.arange(grid_w, device=device) + .view(1, 1, -1) + .expand(len(t_index), grid_h, -1) + .flatten() + ) + t_index = t_index.view(-1, 1).expand(-1, grid_h * grid_w).flatten() + + llm_pos_ids = torch.stack([t_index, h_index, w_index], dim=0) + st_idx + return llm_pos_ids + class DualChunkRotaryEmbedding(CustomOp): """Rotary positional embedding for Dual Chunk Attention.""" diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index e2012e9dea4..60283080bc0 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -280,7 +280,6 @@ def pad_input_tokens( input_ids_tensor[input_ids_tensor == token_id] = pad_value ret_input_ids = input_ids_tensor.tolist() - return ret_input_ids @@ -507,7 +506,7 @@ def embed_mm_inputs( Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, placeholder_tokens: dict[Modality, List[int]] = None, - use_deepstack: bool = False, + use_deepstack: Dict[Modality, bool] = {}, ) -> Optional[torch.Tensor]: """ Embed multimodal inputs and integrate them with text token embeddings. @@ -533,7 +532,9 @@ def embed_mm_inputs( for mm_inputs in mm_inputs_list: item_flatten_list += [item for item in mm_inputs.mm_items if item is not None] - embeddings, masks, deepstack_embeddings = [], [], [] + # deepstack_embeddings: per-modality + modalities, embeddings, masks, deepstack_embeddings = [], [], [], [] + # 2. 
Get multimodal embedding separately # Try get mm embedding if any for modality in Modality.all(): @@ -549,7 +550,8 @@ def embed_mm_inputs( # "image", "video", etc modality_id = modality.name.lower() embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None) - if len(items) != 0 and embedder is not None: + if len(items) != 0: + assert embedder is not None, f"no embedding method found for {modality}" placeholder_tensor = torch.as_tensor( [item.pad_value for item in items], device=input_ids.device, @@ -580,11 +582,12 @@ def embed_mm_inputs( items_offset_list=items_offsets, ) - if use_deepstack and embedding is not None: + if use_deepstack.get(modality, None) and embedding is not None: embedding, deepstack_embedding = ( multimodal_model.separate_deepstack_embeds(embedding) ) deepstack_embeddings += [deepstack_embedding] + modalities += [modality] embeddings += [embedding] masks += [mask] @@ -597,17 +600,14 @@ def embed_mm_inputs( input_ids.clamp_(min=0, max=vocab_size - 1) inputs_embeds = input_embedding(input_ids) - # 4. scatter embeddings into input embedding - # deepstack embedding if use_deepstack: - num_deepstack_embeddings = ( - len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0 - ) + num_deepstack_embeddings = len(multimodal_model.deepstack_visual_indexes) + deepstack_embedding_shape = inputs_embeds.shape[:-1] + ( inputs_embeds.shape[-1] * num_deepstack_embeddings, ) - + # a zero-filled embedding, with the same length of inputs_embeds, but different hidden_size input_deepstack_embeds = torch.zeros( deepstack_embedding_shape, device=inputs_embeds.device, @@ -616,14 +616,16 @@ def embed_mm_inputs( other_info["input_deepstack_embeds"] = input_deepstack_embeds - for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks): + # 4. scatter embeddings into input embedding + for i, modality, embedding, mask in zip( + range(len(embeddings)), modalities, embeddings, masks + ): if embedding is None or mask is None: continue # in-place update indices = torch.where(mask.squeeze(dim=-1))[0] inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype) - - if use_deepstack: + if use_deepstack.get(modality, None): input_deepstack_embeds[indices] = deepstack_embeddings[i].to( inputs_embeds.device, inputs_embeds.dtype ) @@ -640,7 +642,7 @@ def general_mm_embed_routine( Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] ] = None, placeholder_tokens: Optional[dict[Modality, List[int]]] = None, - use_deepstack: bool = False, + use_deepstack: Dict[Modality, bool] = {}, **kwargs, ) -> torch.Tensor: """ @@ -652,7 +654,7 @@ def general_mm_embed_routine( language_model: Base language model to use data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function. 
placeholder_tokens: Token IDs for multimodal placeholders - use_deepstack: Whether to use deepstack embeddings + use_deepstack: Whether to use deepstack embeddings for each modality, default False **kwargs: Additional arguments passed to language model Returns: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e646c2a6cdc..03c15fde952 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -587,9 +587,9 @@ async def _tokenize_one_request( ) if self.mm_processor and obj.contains_mm_input(): - if not isinstance(obj.image_data, list) and obj.image_data: + if obj.image_data is not None and not isinstance(obj.image_data, list): obj.image_data = [obj.image_data] - if not isinstance(obj.audio_data, list) and obj.audio_data: + if obj.audio_data is not None and not isinstance(obj.audio_data, list): obj.audio_data = [obj.audio_data] mm_inputs: Dict = await self.mm_processor.process_mm_data_async( image_data=obj.image_data, diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 5d84a23afdd..05612fca07a 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -518,6 +518,7 @@ def __init__( ) -> None: super().__init__() self.config = config + self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.pp_group = get_pp_group() diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 9991eb96b7c..a3044bef958 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -661,13 +661,14 @@ def __init__( config: Qwen3MoeConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + decoder_layer_type=Qwen3MoeDecoderLayer, ) -> None: alt_stream = torch.cuda.Stream() if _is_cuda else None super().__init__( config=config, quant_config=quant_config, prefix=prefix, - decoder_layer_type=Qwen3MoeDecoderLayer, + decoder_layer_type=decoder_layer_type, alt_stream=alt_stream, ) diff --git a/python/sglang/srt/models/qwen3_omni_moe.py b/python/sglang/srt/models/qwen3_omni_moe.py new file mode 100644 index 00000000000..805e5d7a258 --- /dev/null +++ b/python/sglang/srt/models/qwen3_omni_moe.py @@ -0,0 +1,661 @@ +# Copyright 2025 Qwen Team +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Inference-only Qwen3-VL model compatible with HuggingFace weights.""" +import math +from typing import Iterable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput + +from sglang.srt.configs.qwen3_omni import ( + Qwen3OmniMoeAudioEncoderConfig, + Qwen3OmniMoeThinkerConfig, + Qwen3OmniMoeVisionEncoderConfig, +) +from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import MultimodalDataItem +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel +from sglang.srt.models.qwen3_vl_moe import ( + Qwen3MoeLLMModel, + Qwen3VLMoeForConditionalGeneration, + load_fused_expert_weights, +) +from sglang.srt.utils import add_prefix, logger + + +class Qwen3OmniMoeAudioEncoderLayer(nn.Module): + def __init__( + self, + config: Qwen3OmniMoeAudioEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + embed_dim = config.d_model + self.embed_dim = config.d_model + self.self_attn = VisionAttention( + embed_dim=embed_dim, + num_heads=config.encoder_attention_heads, + projection_size=embed_dim, + use_qkv_parallel=True, + rotary_embed="normal", + proj_bias=True, + qkv_backend="fa3", + softmax_in_single_precision=False, + flatten_batch=True, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + x=hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +def _get_feat_extract_output_lengths(input_lengths): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return output_lengths + + +class Qwen3OmniMoeAudioEncoder(PreTrainedModel): + config: Qwen3OmniMoeAudioEncoderConfig + + def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.positional_embedding = SinusoidsPositionEmbedding( + self.max_source_positions, embed_dim + ) + self.layers = nn.ModuleList( + [ + Qwen3OmniMoeAudioEncoderLayer(config) + for _ in range(config.encoder_layers) + ] + ) + self.ln_post = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False + self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1) + self.conv2d2 = nn.Conv2d( + config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + ) + self.conv2d3 = nn.Conv2d( + config.downsample_hidden_size, + config.downsample_hidden_size, + 3, + 2, + padding=1, + ) + self.conv_out = nn.Linear( + config.downsample_hidden_size + * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), + config.d_model, + bias=False, + ) + self.proj1 = nn.Linear(config.d_model, config.d_model) + self.act = ACT2FN[config.activation_function] + self.proj2 = nn.Linear(config.d_model, config.output_dim) + self.n_window_infer = self.config.n_window_infer + self.conv_chunksize = self.config.conv_chunksize + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def 
set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + ): + r""" + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths[chunk_lengths == 0] = self.n_window * 2 + + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence( + chunk_list, batch_first=True + ).transpose(1, 2) + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [ + torch.ones(length, dtype=torch.bool, device=padded_feature.device) + for length in feature_lens_after_cnn + ], + batch_first=True, + ) + padded_feature = padded_feature.unsqueeze(1) + # Split to chunk to avoid OOM during convolution + padded_embeds = [] + for chunk in padded_feature.split(self.conv_chunksize, dim=0): + padded_embed = F.gelu(self.conv2d1(chunk)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + padded_embeds.append(padded_embed) + padded_embed = torch.cat(padded_embeds, dim=0) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out( + padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f) + ) + + positional_embedding = ( + self.positional_embedding.positional_embedding[: padded_embed.shape[1], :] + .unsqueeze(0) + .to(padded_embed.dtype) + ) + padded_embed = padded_embed + positional_embedding + hidden_states = padded_embed[padded_mask_after_cnn] + cu_chunk_lens = [0] + window_aftercnn = padded_mask_after_cnn.shape[-1] * ( + self.n_window_infer // (self.n_window * 2) + ) + for cnn_len in aftercnn_lens: + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum( + -1, dtype=torch.int32 + ) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.ln_post(hidden_states) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return BaseModelOutput(last_hidden_state=hidden_states) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class Qwen3OmniMoeVisionPatchMerger(nn.Module): + + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_postshuffle_norm=False, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + 
self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = RMSNorm( + self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6 + ) + self.mlp = nn.ModuleList( + [ + ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=add_prefix("mlp.0", prefix), + ), + nn.GELU(), + RowParallelLinear( + self.hidden_size, + dim, + bias=True, + quant_config=quant_config, + prefix=add_prefix("mlp.2", prefix), + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = ( + x.view(-1, self.hidden_size) + if self.use_postshuffle_norm + else x.view(-1, x.shape[-1]) + ) + hidden = self.ln_q(x).view(-1, self.hidden_size) + for layer in self.mlp: + if isinstance(hidden, tuple): + hidden = hidden[0] + hidden = layer(hidden) + + if isinstance(hidden, tuple): + hidden = hidden[0] + + return hidden + + +class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel): + config: Qwen3OmniMoeVisionEncoderConfig + + def __init__( + self, + config: Qwen3OmniMoeVisionEncoderConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = None, + **kwargs, + ): + super().__init__( + vision_config=config, + quant_config=quant_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + ) + + self.merger = Qwen3OmniMoeVisionPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + quant_config=quant_config, + use_postshuffle_norm=False, + prefix=add_prefix("merger", prefix), + ) + self.merger_list = nn.ModuleList( + [ + Qwen3OmniMoeVisionPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + use_postshuffle_norm=True, + quant_config=quant_config, + prefix=add_prefix("merger_list", prefix), + ) + for _ in range(len(config.deepstack_visual_indexes)) + ] + ) + del self.deepstack_merger_list + + @property + def deepstack_merger_list(self): + return self.merger_list + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + +class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration): + config: Qwen3OmniMoeThinkerConfig + + def __init__( + self, + config: Qwen3OmniMoeThinkerConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel + ) + self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config) + self.visual = Qwen3OmniMoeVisionEncoder( + config.vision_config, + quant_config=quant_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + prefix=add_prefix("visual", prefix), + ) + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + + def get_audio_feature(self, items: List[MultimodalDataItem]): + feature_attention_mask = torch.cat( + [item.feature_attention_mask for item in items], dim=0 + ).type(torch.long) + input_features = ( + torch.cat([item.feature for item in items]) + .type(self.audio_tower.dtype) + .to(next(self.audio_tower.parameters()).device) + ) + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + audio_feature_lengths = None + + feature_lens = ( + audio_feature_lengths + if audio_feature_lengths is not 
None + else feature_attention_mask.sum(-1) + ) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + ) + audio_features = audio_outputs.last_hidden_state + + return audio_features + + +class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel): + def __init__( + self, + config: Qwen3VLMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__(config) + self.config = config + + self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration( + config.thinker_config, quant_config=quant_config, prefix=prefix + ) + self.enable_talker = False + self.pad_input_ids = self.thinker.pad_input_ids + self.forward = self.thinker.forward + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = ( + ".bias", + "_bias", + ".k_scale", + "_k_scale", + ".v_scale", + "_v_scale", + ".weight_scale", + "_weight_scale", + ".input_scale", + "_input_scale", + ) + + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + + num_experts = self.config.num_experts + + # Cache params_dict to avoid repeated expensive traversal of model parameters + if not hasattr(self, "_cached_params_dict"): + self._cached_params_dict = dict(self.named_parameters()) + params_dict = self._cached_params_dict + + for name, loaded_weight in weights: + name = name.replace(r"model.language_model.", r"model.") + + if ("talker" in name or "code2wav" in name) and not self.enable_talker: + continue + + name = name.replace(".self_attn.out_proj", ".self_attn.proj") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if "experts.gate_up_proj" in name or "experts.down_proj" in name: + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if "visual" in name: + continue + + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. 
+ if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # [TODO] Skip layers that are on other devices (check if sglang has a similar function) + # if is_pp_missing_parameter(name, self): + # continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Track if this is an expert weight to enable early skipping + is_expert_weight = False + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if "visual" in name or "audio_tower" in name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_fused_expert: + loaded_weight = loaded_weight.transpose(-1, -2) # no bias + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[0], + "w1", + num_experts, + ) + load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight[1], + "w3", + num_experts, + ) + else: + load_fused_expert_weights( + name_mapped, + params_dict, + loaded_weight, + shard_id, + num_experts, + ) + else: + # Skip loading extra parameters for GPTQ/modelopt models. + if ( + name_mapped.endswith(ignore_suffixes) + and name_mapped not in params_dict + ): + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # # other available replicas. + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + ) + name = name_mapped + break + else: + if is_expert_weight: + # This is an expert weight but not mapped to this rank, skip all remaining processing + continue + if "visual" in name or "audio_tower" in name: + # adapt to VisionAttention + name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + name = name.replace(r"model.visual.", r"visual.") + name = name.replace(r"attn.out_proj.", r"attn.proj.") + + # Skip loading extra parameters for GPTQ/modelopt models. 
+ if name.endswith(ignore_suffixes) and name not in params_dict: + continue + + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning( + f"Loaded weight with {name=} not found in params_dict" + ) + + +EntryClass = Qwen3OmniMoeForConditionalGeneration diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 0db6541d33f..c41eb040316 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -15,7 +15,7 @@ """Inference-only Qwen3-VL model compatible with HuggingFace weights.""" import logging from functools import lru_cache, partial -from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Callable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -27,7 +27,11 @@ Qwen2_5_VisionRotaryEmbedding, ) -from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from sglang.srt.configs.qwen3_vl import ( + Qwen3VLConfig, + Qwen3VLTextConfig, + Qwen3VLVisionConfig, +) from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -38,16 +42,24 @@ MultiModalityDataPaddingPatternMultimodalTokens, general_mm_embed_routine, ) -from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) +from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + ForwardMode, + PPProxyTensors, +) from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs from sglang.srt.models.qwen3 import Qwen3Model from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) + # === Vision Encoder === # @@ -196,7 +208,7 @@ def forward( return x -class Qwen3_VisionPatchMerger(nn.Module): +class Qwen3VLMoeVisionPatchMerger(nn.Module): def __init__( self, @@ -246,7 +258,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out -class Qwen3_VisionTransformer(nn.Module): +class Qwen3VLMoeVisionModel(nn.Module): def __init__( self, @@ -263,10 +275,10 @@ def __init__( self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size + # layer indexes of which layer's output should be deep-stacked self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) - norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) @@ -286,7 +298,7 @@ def __init__( for layer_idx in range(vision_config.depth) ] ) - self.merger = Qwen3_VisionPatchMerger( + self.merger = Qwen3VLMoeVisionPatchMerger( dim=vision_config.out_hidden_size, context_dim=self.hidden_size, norm_layer=norm_layer, @@ -297,7 +309,7 @@ def __init__( self.deepstack_merger_list = nn.ModuleList( [ - 
Qwen3_VisionPatchMerger( + Qwen3VLMoeVisionPatchMerger( dim=vision_config.out_hidden_size, context_dim=self.hidden_size, spatial_merge_size=self.spatial_merge_size, @@ -462,7 +474,6 @@ def forward( ] ) - # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) x = x.unsqueeze(1) deepstack_feature_lists = [] @@ -604,37 +615,43 @@ def __init__( config: Qwen3VLConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + language_model_cls=Qwen3LLMModel, ) -> None: super().__init__() - self.config = config - self.visual = Qwen3_VisionTransformer( + self.visual = Qwen3VLMoeVisionModel( config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. quant_config=quant_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), prefix=add_prefix("visual", prefix), ) - self.model = Qwen3LLMModel( - config=config, + # TODO: make it more elegant + if language_model_cls is Qwen3LLMModel: + self.config: Qwen3VLConfig = config # for qwen3-vl + else: + self.config = config.text_config # for qwen3-omni + + self.model = language_model_cls( + config=self.config, quant_config=quant_config, prefix=add_prefix("model", prefix), ) - if config.tie_word_embeddings: + if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, + self.config.vocab_size, + self.config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), ) self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling - self.logits_processor = LogitsProcessor(config) + self.logits_processor = LogitsProcessor(self.config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states @@ -642,10 +659,7 @@ def __init__( # deepstack self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) - - @property - def use_deepstack(self) -> bool: - return hasattr(self, "deepstack_visual_indexes") + self.use_deepstack = {Modality.IMAGE: True, Modality.VIDEO: True} def separate_deepstack_embeds(self, embedding): assert ( diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index 507403adbc2..c4d56a25701 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -14,29 +14,19 @@ # ============================================================================== """Inference-only Qwen3-VL model compatible with HuggingFace weights.""" import logging -from functools import lru_cache, partial -from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union +from functools import lru_cache +from typing import Iterable, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from transformers import BatchFeature -from transformers.activations import ACT2FN -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VisionRotaryEmbedding, -) -from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig +from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig from 
sglang.srt.distributed import ( get_moe_expert_parallel_world_size, - get_pp_group, get_tensor_model_parallel_rank, ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.mm_utils import general_mm_embed_routine @@ -44,11 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen3_moe import Qwen3MoeModel -from sglang.srt.models.qwen3_vl import ( - Qwen3_VisionTransformer, - Qwen3VLForConditionalGeneration, -) -from sglang.srt.utils import add_prefix +from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration from sglang.srt.utils.hf_transformers_utils import get_processor logger = logging.getLogger(__name__) @@ -60,28 +46,16 @@ class Qwen3MoeLLMModel(Qwen3MoeModel): def __init__( self, *, - config: Qwen3VLMoeConfig, + config: Qwen3VLMoeTextConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__(config=config, quant_config=quant_config, prefix=prefix) - self.hidden_size = config.hidden_size def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens - def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - # in qwen-vl, last dim is the same - pixel_values = torch.cat([item.feature for item in items], dim=0).type( - self.visual.dtype - ) - image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) - assert pixel_values.dim() == 2, pixel_values.dim() - assert image_grid_thw.dim() == 2, image_grid_thw.dim() - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - return image_embeds - def forward( self, input_ids: torch.Tensor, @@ -120,7 +94,7 @@ def forward( ) # process deepstack - if input_deepstack_embeds is not None and layer_idx in range(3): + if input_deepstack_embeds is not None and layer_idx < 3: sep = self.hidden_size * layer_idx hidden_states.add_( input_deepstack_embeds[:, sep : sep + self.hidden_size] @@ -146,144 +120,56 @@ def forward( return hidden_states, aux_hidden_states -class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): - def __init__( - self, - *, - config: Qwen3VLMoeConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super(Qwen3VLForConditionalGeneration, self).__init__() - self.config = config - - self.visual = Qwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. - # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported. 
- quant_config=quant_config, - prefix=add_prefix("visual", prefix), - ) - - self.model = Qwen3MoeLLMModel( - config=config, - quant_config=quant_config, - prefix=add_prefix("model", prefix), - ) - - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("lm_head", prefix), +def load_fused_expert_weights( + name: str, + params_dict: dict, + loaded_weight: torch.Tensor, + shard_id: str, + num_experts: int, +): + param = params_dict[name] + # weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + weight_loader = param.weight_loader + ep_rank = get_tensor_model_parallel_rank() + ep_size = get_moe_expert_parallel_world_size() + if ep_size == 1: + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + weight_loader( + param, + curr_expert_weight, + name, + shard_id, + expert_id, ) - self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling - - self.logits_processor = LogitsProcessor(config) - self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - - # deepstack - self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes - self.num_deepstack_embeddings = len(self.deepstack_visual_indexes) - - @property - def use_deepstack(self) -> bool: - return hasattr(self, "deepstack_visual_indexes") - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - forward_batch: ForwardBatch, - get_embedding: bool = False, - ): - """Run forward pass for Qwen3-VL. - - Args: - input_ids: Flattened (concatenated) input_ids corresponding to a - batch. - positions: Flattened (concatenated) position ids corresponding to a - batch. - **NOTE**: If mrope is enabled (default setting for Qwen2-VL - opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). 
- (Use input_metadata.mrope_positions to replace it) - """ - if self.is_mrope_enabled: - positions = forward_batch.mrope_positions - - if not ( - forward_batch.forward_mode.is_decode() - or not forward_batch.contains_image_inputs() - ): - if self.is_mrope_enabled: - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}" - ) - - hidden_states = general_mm_embed_routine( - input_ids=input_ids, - forward_batch=forward_batch, - language_model=self.model, - multimodal_model=self, - positions=positions, - use_deepstack=self.use_deepstack, + else: + experts_per_ep = num_experts // ep_size + start_expert = ep_rank * experts_per_ep + end_expert = ( + (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts ) - if not get_embedding: - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + for idx, expert_id in enumerate(range(start_expert, end_expert)): + curr_expert_weight = loaded_weight[expert_id] + weight_loader( + param, + curr_expert_weight, + name, + shard_id, + idx, ) - else: - return self.pooler(hidden_states, forward_batch) + return True - def load_fused_expert_weights( + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + def __init__( self, - name: str, - params_dict: dict, - loaded_weight: torch.Tensor, - shard_id: str, - num_experts: int, + config: Qwen3VLMoeConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + language_model_cls=Qwen3MoeLLMModel, ): - param = params_dict[name] - # weight_loader = typing.cast(Callable[..., bool], param.weight_loader) - weight_loader = param.weight_loader - ep_rank = get_tensor_model_parallel_rank() - ep_size = get_moe_expert_parallel_world_size() - if ep_size == 1: - for expert_id in range(num_experts): - curr_expert_weight = loaded_weight[expert_id] - weight_loader( - param, - curr_expert_weight, - name, - shard_id, - expert_id, - ) - else: - experts_per_ep = num_experts // ep_size - start_expert = ep_rank * experts_per_ep - end_expert = ( - (ep_rank + 1) * experts_per_ep - if ep_rank != ep_size - 1 - else num_experts - ) - - for idx, expert_id in enumerate(range(start_expert, end_expert)): - curr_expert_weight = loaded_weight[expert_id] - weight_loader( - param, - curr_expert_weight, - name, - shard_id, - idx, - ) - return True + super().__init__(config, quant_config, prefix, language_model_cls) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -329,8 +215,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self._cached_params_dict = dict(self.named_parameters()) params_dict = self._cached_params_dict for name, loaded_weight in weights: - if "language_model" in name: - name = name.replace(r"model.language_model.", r"model.") + name = name.replace(r"model.language_model.", r"model.") for param_name, weight_name, shard_id in stacked_params_mapping: if "experts.gate_up_proj" in name or "experts.down_proj" in name: @@ -384,14 +269,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight.transpose(-1, -2) # no bias if "experts.gate_up_proj" in name: loaded_weight = loaded_weight.chunk(2, dim=-2) - self.load_fused_expert_weights( + load_fused_expert_weights( name_mapped, params_dict, loaded_weight[0], "w1", num_experts, ) - self.load_fused_expert_weights( + load_fused_expert_weights( name_mapped, params_dict, loaded_weight[1], @@ 
-399,7 +284,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): num_experts, ) else: - self.load_fused_expert_weights( + load_fused_expert_weights( name_mapped, params_dict, loaded_weight, diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index ef076ae0931..91b8ada745f 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -155,7 +155,6 @@ def __init__( ): self.hf_config = hf_config self._processor = _processor - self.arch = hf_config.architectures[0] self.server_args = server_args self.transport_mode = transport_mode @@ -191,6 +190,7 @@ def __init__( "input_features": Modality.AUDIO, "input_features_mask": Modality.AUDIO, "audio_attention_mask": Modality.AUDIO, + "feature_attention_mask": Modality.AUDIO, # Video-related attributes "pixel_values_videos": Modality.VIDEO, "second_per_grid_ts": Modality.VIDEO, @@ -222,6 +222,7 @@ def process_mm_data( if self._processor.__class__.__name__ in { "Gemma3nProcessor", "Qwen2AudioProcessor", + "Qwen3OmniMoeProcessor", }: # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107 kwargs["audio"] = audios diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index ec5e574f434..b6b899ebdc5 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -12,6 +12,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration +from sglang.srt.models.qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration from sglang.srt.models.qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( @@ -209,22 +210,31 @@ async def preprocess_video( return video -# Compatible with Qwen2VL and Qwen2_5VL -class Qwen2_5VLImageProcessor(SGLangBaseProcessor): +# Compatible with Qwen-VL & Qwen-Omni Series +class QwenVLImageProcessor(SGLangBaseProcessor): models = [ Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration, + Qwen3OmniMoeForConditionalGeneration, ] def __init__(self, hf_config, server_args, _processor, *args, **kwargs): + self.model_type = hf_config.model_type + if hf_config.model_type == "qwen3_omni_moe": + hf_config = hf_config.thinker_config + super().__init__(hf_config, server_args, _processor, *args, **kwargs) - # The regex that matches expanded image tokens. 
+ self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = hf_config.vision_end_token_id + + self.audio_start_token_id = getattr(hf_config, "audio_start_token_id", None) + self.audio_token_id = getattr(hf_config, "audio_token_id", None) + self.NUM_TOKEN_PER_FRAME = 770 self.IMAGE_FACTOR = 28 self.MIN_PIXELS = 4 * 28 * 28 @@ -233,10 +243,12 @@ def __init__(self, hf_config, server_args, _processor, *args, **kwargs): self.mm_tokens = MultimodalSpecialTokens( image_token="<|vision_start|><|image_pad|><|vision_end|>", image_token_id=hf_config.image_token_id, + # The regex that matches expanded image tokens. image_token_regex=re.compile( r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" ), video_token_id=hf_config.video_token_id, + audio_token_id=self.audio_token_id, ).build(_processor) async def process_mm_data_async( @@ -247,11 +259,11 @@ async def process_mm_data_async( *args, **kwargs, ): - base_output = self.load_mm_data( prompt=input_text, image_data=image_data, video_data=request_obj.video_data, + audio_data=request_obj.audio_data, multimodal_tokens=self.mm_tokens, ) @@ -269,20 +281,41 @@ async def process_mm_data_async( base_output, self.mm_tokens ) + audio_feature_lengths = None + + if self.model_type == "qwen3_omni_moe": + audio_item = next((mm for mm in mm_items if mm.is_audio()), None) + if audio_item: + audio_feature_lengths = torch.sum( + audio_item.feature_attention_mask, dim=1 + ) + + second_per_grid_ts = getattr(ret, "second_per_grid_ts", None) or getattr( + ret, "video_second_per_grid", None + ) + input_ids = input_ids.flatten() + mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, image_token_id=self.mm_tokens.image_token_id, video_token_id=self.mm_tokens.video_token_id, vision_start_token_id=self.vision_start_token_id, - model_type=self.hf_config.model_type, + model_type=self.model_type, tokens_per_second=getattr( self.hf_config.vision_config, "tokens_per_second", None ), input_ids=input_ids.unsqueeze(0), image_grid_thw=getattr(ret, "image_grid_thw", None), video_grid_thw=getattr(ret, "video_grid_thw", None), - second_per_grid_ts=getattr(ret, "second_per_grid_ts", None), + second_per_grid_ts=second_per_grid_ts, + use_audio_in_video=False, + audio_seqlens=audio_feature_lengths, + audio_token_id=getattr(self.hf_config, "audio_token_id", None), + audio_start_token_id=self.audio_start_token_id, + position_id_per_seconds=getattr( + self.hf_config, "position_id_per_seconds", None + ), ) mrope_positions = mrope_positions.squeeze(1) @@ -293,6 +326,7 @@ async def process_mm_data_async( "im_end_id": self.IM_END_TOKEN_ID, "im_token_id": self.mm_tokens.image_token_id, "video_token_id": self.mm_tokens.video_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, "mrope_positions": mrope_positions, "mrope_position_delta": mrope_position_delta, } diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index b6861b99c01..b8f4c64c481 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -355,9 +355,10 @@ def test_audio_ambient_completion(self): if __name__ == "__main__": del ( - TestOpenAIOmniServerBase, + TestOpenAIMLLMServerBase, ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin, + OmniOpenAITestMixin, ) unittest.main() diff --git 
a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index 963036aee86..304896e73cd 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -241,11 +241,35 @@ def setUpClass(cls): cls.base_url += "/v1" +class TestQwen3OmniServer(OmniOpenAITestMixin): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ # workaround to fit into H100 + "--trust-remote-code", + "--mem-fraction-static", + "0.90", + "--disable-cuda-graph", + "--disable-fast-image-processor", + "--grammar-backend", + "none", + ], + ) + cls.base_url += "/v1" + + if __name__ == "__main__": del ( - TestOpenAIOmniServerBase, + TestOpenAIMLLMServerBase, ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin, + OmniOpenAITestMixin, ) unittest.main() diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 6af8f099ce5..ec8a5fce302 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -1,6 +1,7 @@ import base64 import io import os +from concurrent.futures import ThreadPoolExecutor import numpy as np import openai @@ -22,7 +23,7 @@ AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" -class TestOpenAIOmniServerBase(CustomTestCase): +class TestOpenAIMLLMServerBase(CustomTestCase): @classmethod def setUpClass(cls): cls.model = "" @@ -58,7 +59,20 @@ def get_or_download_file(self, url: str) -> str: return file_path -class AudioOpenAITestMixin(TestOpenAIOmniServerBase): +class AudioOpenAITestMixin(TestOpenAIMLLMServerBase): + def verify_speech_recognition_response(self, text): + check_list = [ + "thank you", + "it's a privilege to be here", + "leader", + "science", + "art", + ] + for check_word in check_list: + assert ( + check_word in text.lower() + ), f"audio_response: |{text}| should contain |{check_word}|" + def prepare_audio_messages(self, prompt, audio_file_name): messages = [ { @@ -116,17 +130,7 @@ def test_audio_speech_completion(self): "Listen to this audio and write down the audio transcription in English.", category="speech", ) - check_list = [ - "thank you", - "it's a privilege to be here", - "leader", - "science", - "art", - ] - for check_word in check_list: - assert ( - check_word in audio_response - ), f"audio_response: |{audio_response}| should contain |{check_word}|" + self.verify_speech_recognition_response(audio_response) def test_audio_ambient_completion(self): # bird song @@ -138,26 +142,39 @@ def test_audio_ambient_completion(self): assert "bird" in audio_response -class ImageOpenAITestMixin(TestOpenAIOmniServerBase): - def test_single_image_chat_completion(self): +class ImageOpenAITestMixin(TestOpenAIMLLMServerBase): + def run_decode_with_image(self, image_id): client = openai.Client(api_key=self.api_key, base_url=self.base_url) + content = [] + if image_id == 0: + content.append( + { + "type": "image_url", + "image_url": {"url": IMAGE_MAN_IRONING_URL}, + } + ) + elif image_id == 1: + content.append( + { + "type": "image_url", + "image_url": {"url": IMAGE_SGL_LOGO_URL}, + } + ) + else: + pass + + content.append( + { + "type": "text", + "text": "Describe this image in a sentence.", + } + ) + response = client.chat.completions.create( 
model="default", messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": IMAGE_MAN_IRONING_URL}, - }, - { - "type": "text", - "text": "Describe this image in a sentence.", - }, - ], - }, + {"role": "user", "content": content}, ], temperature=0, **(self.get_vision_request_kwargs()), @@ -166,6 +183,17 @@ def test_single_image_chat_completion(self): assert response.choices[0].message.role == "assistant" text = response.choices[0].message.content assert isinstance(text, str) + + def test_mixed_batch(self): + image_ids = [0, 1, 2] * 4 + with ThreadPoolExecutor(4) as executor: + list(executor.map(self.run_decode_with_image, image_ids)) + + def verify_single_image_response(self, response): + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + # `driver` is for gemma-3-it assert ( "man" in text or "person" or "driver" in text @@ -179,19 +207,44 @@ def test_single_image_chat_completion(self): ), f"text: {text}, should contain cab, taxi, SUV, vehicle or car" # MiniCPMO fails to recognize `iron`, but `hanging` assert ( - "iron" in text - or "hang" in text - or "cloth" in text - or "coat" in text - or "holding" in text - or "outfit" in text - ), f"text: {text}, should contain iron, hang, cloth, coat or holding or outfit" + "iron" in text or "hang" in text or "cloth" in text or "holding" in text + ), f"text: {text}, should contain iron, hang, cloth or holding" assert response.id assert response.created assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + def test_single_image_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": IMAGE_MAN_IRONING_URL}, + }, + { + "type": "text", + "text": "Describe this image in a sentence.", + }, + ], + }, + ], + temperature=0, + **(self.get_vision_request_kwargs()), + ) + + print("-" * 30) + print(f"Single image response:\n{response.choices[0].message.content}") + print("-" * 30) + + self.verify_single_image_response(response) + def test_multi_turn_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -264,8 +317,7 @@ def test_multi_images_chat_completion(self): }, { "type": "text", - "text": "I have two very different images. They are not related at all. " - "Please describe the first image in one sentence, and then describe the second image in another sentence.", + "text": "I have two very different images. 
Please describe them.", }, ], }, @@ -296,64 +348,6 @@ def test_multi_images_chat_completion(self): assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 - def _test_mixed_image_audio_chat_completion(self): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - - response = client.chat.completions.create( - model="default", - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": IMAGE_MAN_IRONING_URL}, - }, - { - "type": "audio_url", - "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL}, - }, - { - "type": "text", - "text": "Please describe the image in one sentence, and then write down the audio transcription in English.", - }, - ], - }, - ], - temperature=0, - **(self.get_vision_request_kwargs()), - ) - - assert response.choices[0].message.role == "assistant" - text = response.choices[0].message.content - assert isinstance(text, str) - print("-" * 30) - print(f"Mixed image & audio response:\n{text}") - print("-" * 30) - assert ( - "man" in text - or "cab" in text - or "SUV" in text - or "taxi" in text - or "car" in text - ), f"text: {text}, should contain man, cab, SUV, taxi or car" - check_list = [ - "thank you", - "it's a privilege to be here", - "leader", - "science", - "art", - ] - for check_word in check_list: - assert ( - check_word in text - ), f"text: |{text}| should contain |{check_word}|" - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 - def prepare_video_images_messages(self, video_path): # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa # the size of the video embeds differs from the `modality` argument when preprocessed @@ -461,7 +455,7 @@ def test_video_images_chat_completion(self): self.assertGreater(len(video_response), 0) -class VideoOpenAITestMixin(TestOpenAIOmniServerBase): +class VideoOpenAITestMixin(TestOpenAIMLLMServerBase): def prepare_video_messages(self, video_path): messages = [ { @@ -526,3 +520,45 @@ def test_video_chat_completion(self): ), f"video_response: {video_response}, should contain 'black' or 'dark'" self.assertIsNotNone(video_response) self.assertGreater(len(video_response), 0) + + +class OmniOpenAITestMixin( + ImageOpenAITestMixin, VideoOpenAITestMixin, AudioOpenAITestMixin +): + def test_mixed_modality_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": IMAGE_MAN_IRONING_URL}, + }, + { + "type": "audio_url", + "audio_url": {"url": AUDIO_TRUMP_SPEECH_URL}, + }, + { + "type": "text", + "text": "I have an image and audio, which are not related at all. Please: 1. Describe the image in a sentence, 2. Repeat the exact words from the audio I provided. Be exact", + }, + ], + }, + ] + response = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=128, + stream=False, + ) + + text = response.choices[0].message.content + + print("-" * 30) + print(f"Mixed modality response:\n{text}") + print("-" * 30) + + self.verify_single_image_response(response=response) + self.verify_speech_recognition_response(text=text)
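
A short note on the deepstack path above: Qwen3MoeLLMModel.forward receives the per-index vision features concatenated along the hidden dimension and, for layer_idx < 3, adds slice i onto the hidden states produced by decoder layer i. The standalone sketch below only illustrates that slicing arithmetic; the tensor shapes and variable names are made up for the example and are not taken from the patch.

import torch

# Illustrative stand-ins; the real values come from the text config and the
# length of deepstack_visual_indexes (three indexes, e.g. {8: 0, 16: 1, 24: 2}).
hidden_size = 8
num_deepstack = 3
seq_len = 4

# Flattened deepstack embeddings: (seq_len, num_deepstack * hidden_size).
input_deepstack_embeds = torch.randn(seq_len, num_deepstack * hidden_size)
hidden_states_per_layer = [
    torch.zeros(seq_len, hidden_size) for _ in range(num_deepstack)
]

for layer_idx, hidden_states in enumerate(hidden_states_per_layer):
    sep = hidden_size * layer_idx
    # Mirrors hidden_states.add_(input_deepstack_embeds[:, sep : sep + hidden_size])
    # in the model: layer i only sees its own H-wide slice of the fused tensor.
    hidden_states.add_(input_deepstack_embeds[:, sep : sep + hidden_size])
    assert hidden_states.shape == (seq_len, hidden_size)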
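
The module-level load_fused_expert_weights helper splits the fused expert tensor across expert-parallel ranks: every rank owns num_experts // ep_size consecutive experts, the last rank also absorbs any remainder, and each expert weight is handed to the parameter's weight_loader with a local expert id that restarts at 0 on every rank. Below is a minimal standalone sketch of just that partitioning arithmetic; the function name and the example numbers are illustrative, while the real code obtains ep_rank and ep_size from sglang's distributed helpers and dispatches through weight_loader.

def expert_range(ep_rank: int, ep_size: int, num_experts: int) -> range:
    """Global expert ids owned by one expert-parallel rank."""
    experts_per_ep = num_experts // ep_size
    start = ep_rank * experts_per_ep
    # The last rank absorbs the remainder when num_experts % ep_size != 0.
    end = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    return range(start, end)


if __name__ == "__main__":
    num_experts, ep_size = 128, 4  # example values only
    for ep_rank in range(ep_size):
        owned = expert_range(ep_rank, ep_size, num_experts)
        # Local expert ids passed to weight_loader restart at 0 per rank.
        print(
            f"rank {ep_rank}: global experts {owned.start}..{owned.stop - 1}, "
            f"local ids 0..{len(owned) - 1}"
        )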
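
For context on what the new TestQwen3OmniServer / OmniOpenAITestMixin cases exercise, a client-side sketch of a mixed image + audio request follows. It assumes a server has already been launched with Qwen/Qwen3-Omni-30B-A3B-Instruct using flags similar to the test above; the base_url, api_key, and media URLs are placeholders.

import openai

client = openai.Client(api_key="sk-123456", base_url="http://127.0.0.1:30000/v1")

response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "<IMAGE_URL>"}},
                {"type": "audio_url", "audio_url": {"url": "<AUDIO_URL>"}},
                {
                    "type": "text",
                    "text": "Describe the image in one sentence, "
                    "then transcribe the audio.",
                },
            ],
        },
    ],
    temperature=0,
    max_tokens=128,
)
print(response.choices[0].message.content)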