diff --git a/README.md b/README.md
index d204fddff6..d4922e29e9 100644
--- a/README.md
+++ b/README.md
@@ -162,6 +162,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
Qwen-VL (7B)
Qwen2-VL (2B, 7B, 72B)
Qwen2.5-VL (3B, 7B, 72B)
+ Qwen3-VL (2B - 235B)
DeepSeek-VL (7B)
DeepSeek-VL2 (3B, 16B, 27B)
InternVL-Chat (v1.1-v1.5)
diff --git a/README_ja.md b/README_ja.md
index 75d05390ad..5dda14c041 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -148,6 +148,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Qwen-VL (7B)
Qwen2-VL (2B, 7B, 72B)
Qwen2.5-VL (3B, 7B, 72B)
+ Qwen3-VL (2B - 235B)
DeepSeek-VL (7B)
DeepSeek-VL2 (3B, 16B, 27B)
InternVL-Chat (v1.1-v1.5)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f6f10a5b42..2e5f124d20 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -163,6 +163,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Qwen-VL (7B)
Qwen2-VL (2B, 7B, 72B)
Qwen2.5-VL (3B, 7B, 72B)
+ Qwen3-VL (2B - 235B)
DeepSeek-VL (7B)
DeepSeek-VL2 (3B, 16B, 27B)
InternVL-Chat (v1.1-v1.5)
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index aa28854d8a..d7f2bffa05 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -87,6 +87,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
+| QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index 8e9e3fef20..73dd304e98 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -87,6 +87,7 @@
| Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes |
| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
| QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No |
+| QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No |
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py
index faf5e88157..444a2026a3 100644
--- a/lmdeploy/archs.py
+++ b/lmdeploy/archs.py
@@ -109,9 +109,9 @@ def check_vl_llm(config: dict) -> bool:
'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', 'CogVLMForCausalLM', 'InternLMXComposer2ForCausalLM',
'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration',
'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration',
- 'MllamaForConditionalGeneration', 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration',
- 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration',
- 'Glm4vForConditionalGeneration'
+ 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'MllamaForConditionalGeneration',
+ 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration',
+ 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'Glm4vForConditionalGeneration'
])
if arch == 'QWenLMHeadModel' and 'visual' in config:
return True
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index ac3459e045..da1d27f8c7 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -28,9 +28,9 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
config.dtype = torch.float16
return config
- torch_dtype = getattr(config.hf_config, 'dtype', None)
+ torch_dtype = getattr(config.llm_config, 'dtype', None)
if torch_dtype is None:
- torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+ torch_dtype = getattr(config.llm_config, 'torch_dtype', None)
# deal with case when torch_dtype is not string but torch.dtype
if isinstance(torch_dtype, torch.dtype):
diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py
index e30ae7c089..4d06cd10ce 100644
--- a/lmdeploy/pytorch/configurations/default.py
+++ b/lmdeploy/pytorch/configurations/default.py
@@ -14,8 +14,16 @@ def condition(cls, hf_config):
@classmethod
def build(cls, hf_config, model_path: str = None, **kwargs):
"""build."""
+
+        # For multi-modal models, build the model config from the language-model sub-config
+ if hasattr(hf_config, 'text_config'):
+ hf_config = hf_config.text_config
+ elif hasattr(hf_config, 'llm_config'):
+ hf_config = hf_config.llm_config
+
head_dim = getattr(hf_config, 'head_dim', None)
head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads
+
# head_dim should not be None
hf_config.head_dim = head_dim
num_attention_heads = hf_config.num_attention_heads
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 498e2c6554..5441e1c5d3 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -147,6 +147,18 @@
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration',
})
+# qwen3_vl
+MODULE_MAP.update({
+ 'Qwen3VLForConditionalGeneration':
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl.Qwen3VLForConditionalGeneration',
+})
+
+# qwen3_vl_moe
+MODULE_MAP.update({
+ 'Qwen3VLMoeForConditionalGeneration':
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration',
+})
+
# starcoder2
MODULE_MAP.update({
'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM',
diff --git a/lmdeploy/pytorch/models/qwen3.py b/lmdeploy/pytorch/models/qwen3.py
index c362df2fe8..381bfb72cb 100644
--- a/lmdeploy/pytorch/models/qwen3.py
+++ b/lmdeploy/pytorch/models/qwen3.py
@@ -47,7 +47,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
head_dim,
num_kv_heads=num_key_value_heads,
v_head_size=head_dim,
- sliding_window=config.sliding_window,
+ sliding_window=getattr(config, 'sliding_window', None),
)
# o_proj
diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py
index 464953f264..d66ad10ebf 100644
--- a/lmdeploy/pytorch/models/qwen3_moe.py
+++ b/lmdeploy/pytorch/models/qwen3_moe.py
@@ -52,7 +52,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
head_dim,
num_kv_heads=num_key_value_heads,
v_head_size=head_dim,
- sliding_window=config.sliding_window,
+ sliding_window=getattr(config, 'sliding_window', None),
)
# o_proj
diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py
new file mode 100644
index 0000000000..6844c6b8d0
--- /dev/null
+++ b/lmdeploy/pytorch/models/qwen3_vl.py
@@ -0,0 +1,794 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+
+from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.nn import LayerNorm
+from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding
+from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3VLInputProcessor
+from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention
+from .qwen3 import Qwen3Model
+from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
+from .utils.model import DeployModelMixin, vlm_model
+
+
+class Qwen3VLTextRotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: PretrainedConfig, device=None):
+ super().__init__()
+ if hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get('rope_type', 'default')
+ else:
+ self.rope_type = 'default'
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+        rope_scaling = getattr(config, 'rope_scaling', None) or {}
+        self.mrope_section = rope_scaling.get('mrope_section', [24, 20, 20])
+
+ def apply_interleaved_mrope(self, freqs, mrope_section):
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
+ args:
+ x: (3, bs, seq_len, head_dim // 2)
+ mrope_section: (3,)
+ returns:
+ x_t: (bs, seq_len, head_dim // 2)
+ """
+ freqs_t = freqs[0] # just overwrite the first dimension T
+ for dim, offset in enumerate((1, 2), start=1): # H, W
+ length = mrope_section[dim] * 3
+ idx = slice(offset, length, 3)
+ freqs_t[..., idx] = freqs[dim, ..., idx]
+ return freqs_t
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen3VL has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ if position_ids.ndim == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Qwen3VLTextModel(Qwen3Model):
+    """Text part of Qwen3VL.
+
+    Not a pure text-only model: DeepStack integrates visual features into the early decoder hidden states.
+    """
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__(config=config, dtype=dtype, device=device)
+
+ # build rotary embedding
+ # TODO: zhouxinyu, add triton kernel for interleaved mrope
+ self.rotary_emb = Qwen3VLTextRotaryEmbedding(config, device=device)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mrope_position_ids: torch.LongTensor = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ ):
+ """visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`,
+ *optional*):
+
+ The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack
+ visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the
+ different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack (
+ https://arxiv.org/abs/2406.04)
+ """
+
+ # token embedding
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ # rotary embedding
+ if mrope_position_ids is None:
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
+ else:
+ mrope_position_ids = mrope_position_ids.unsqueeze(1)
+ cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)
+
+ cos, sin = cos[0], sin[0]
+ rotary_pos_emb = (cos, sin)
+
+ # decoding
+ residual = None
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx]
+ hidden_states, residual = decoder_layer(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ residual=residual,
+ attn_metadata=attn_metadata,
+ )
+
+ # add visual features to the hidden states of first several layers
+            if deepstack_visual_embeds is not None and idx < len(deepstack_visual_embeds):
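+                # fold the pending residual into hidden_states before scattering the
+                # deepstack visual features, then restart the residual stream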
+ hidden_states = hidden_states + residual
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[idx],
+ )
+ residual = None
+
+ # norm
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
+ visual_embeds: torch.Tensor):
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local = torch.zeros_like(hidden_states)
+ local.masked_scatter_(visual_pos_masks, visual_embeds)
+ hidden_states += local
+ return hidden_states
+
+
+class Qwen3VLVisionPatchEmbed(nn.Module):
+
+ def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None:
+ super().__init__()
+ self.patch_size = config.patch_size
+ self.temporal_patch_size = config.temporal_patch_size
+ self.in_channels = config.in_channels
+ self.embed_dim = config.hidden_size
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels,
+ self.embed_dim,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=True,
+ dtype=dtype,
+ device=device)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size,
+ self.patch_size)
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen3VLVisionMLP(nn.Module):
+ """Vision mlp."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ from transformers.activations import ACT2FN
+ hidden_dim = config.hidden_size
+ intermediate_size = config.intermediate_size
+ quantization_config = getattr(config, 'quantization_config', None)
+ # gate up
+ self.linear_fc1 = build_colwise_linear(
+ hidden_dim,
+ intermediate_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # gelu_pytorch_tanh
+ self.act = ACT2FN[config.hidden_act]
+
+ # down
+ self.linear_fc2 = build_rowwise_linear(intermediate_size,
+ hidden_dim,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True)
+
+ def forward(self, x):
+ """forward."""
+ return self.linear_fc2(self.act(self.linear_fc1(x)))
+
+
+class Qwen3VLVisionBlock(nn.Module):
+ """Vision block."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.norm1 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)
+ self.norm2 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device)
+
+ self.attn = Qwen3VLVisionAttention(config, dtype=dtype, device=device)
+
+ self.mlp = Qwen3VLVisionMLP(config, dtype=dtype, device=device)
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Qwen3VLVisionPatchMerger(nn.Module):
+
+ def __init__(self,
+ config: PretrainedConfig,
+ use_postshuffle_norm=False,
+ dtype: torch.dtype = None,
+ device: torch.device = None) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size,
+ eps=1e-6,
+ dtype=dtype,
+ device=device)
+ self.linear_fc1 = build_colwise_linear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ )
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = build_rowwise_linear(
+ self.hidden_size,
+ config.out_hidden_size,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
+ x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return x
+
+
+@vlm_model
+class Qwen3VLVisionModel(nn.Module):
+ """Vision transformer."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.spatial_merge_size = config.spatial_merge_size
+
+ self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device)
+
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device)
+
+ self.blocks = nn.ModuleList(
+ [Qwen3VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)])
+ self.merger = Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=False, dtype=dtype, device=device)
+
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
+ self.deepstack_merger_list = nn.ModuleList([
+ Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=True, dtype=dtype, device=device)
+ for _ in range(len(config.deepstack_visual_indexes))
+ ])
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
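+        # Build (row, col) patch coordinates for every frame of every image, then look
+        # them up in the cached 1D frequency table and flatten to per-token embeddings.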
+ merge_size = self.spatial_merge_size
+
+ max_hw = int(grid_thw[:, 1:].max().item())
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
+ device = freq_table.device
+
+ total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
+ pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
+
+ offset = 0
+ for num_frames, height, width in grid_thw:
+ merged_h, merged_w = height // merge_size, width // merge_size
+
+ block_rows = torch.arange(merged_h, device=device) # block row indices
+ block_cols = torch.arange(merged_w, device=device) # block col indices
+ intra_row = torch.arange(merge_size, device=device) # intra-block row offsets
+ intra_col = torch.arange(merge_size, device=device) # intra-block col offsets
+
+ # Compute full-resolution positions
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+ row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+ col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
+
+ coords = torch.stack((row_idx, col_idx), dim=-1)
+
+ if num_frames > 1:
+ coords = coords.repeat(num_frames, 1)
+
+ num_tokens = coords.shape[0]
+ pos_ids[offset:offset + num_tokens] = coords
+ offset += num_tokens
+
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
+ embeddings = embeddings.flatten(1)
+ return embeddings
+
+ def fast_pos_embed_interpolate(self, grid_thw):
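+        # Bilinearly interpolate the learned (num_grid_per_side x num_grid_per_side)
+        # position-embedding table onto each image's (h, w) patch grid: every target
+        # position gathers its four neighbouring table entries and blends them with
+        # (1 - dh) * (1 - dw), (1 - dh) * dw, dh * (1 - dw) and dh * dw weights.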
+ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+ h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+ w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.int()
+ w_idxs_floor = w_idxs.int()
+ h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+ w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ base_h = h_idxs_floor * self.num_grid_per_side
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+ indices = [
+ (base_h[None].T + w_idxs_floor[None]).flatten(),
+ (base_h[None].T + w_idxs_ceil[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+ ]
+
+ weights = [
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+ ((1 - dh)[None].T * dw[None]).flatten(),
+ (dh[None].T * (1 - dw)[None]).flatten(),
+ (dh[None].T * dw[None]).flatten(),
+ ]
+
+ for i in range(4):
+ idx_list[i].extend(indices[i].tolist())
+ weight_list[i].extend(weights[i].tolist())
+
+ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+ weight_tensor = torch.tensor(weight_list,
+ dtype=self.pos_embed.weight.dtype,
+ device=self.pos_embed.weight.device)
+ pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+ patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+ patch_pos_embeds_permute = []
+ merge_size = self.config.spatial_merge_size
+ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+ pos_embed = pos_embed.repeat(t, 1)
+ pos_embed = (pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size,
+ -1).permute(0, 1, 3, 2, 4, 5).flatten(0, 4))
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor,
+ pos_embeds: torch.Tensor) -> torch.Tensor:
+ """forward."""
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = hidden_states + pos_embeds
+ cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+
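+        # collect intermediate features at the configured layers; each one goes through
+        # its own patch merger and is later injected into the early LLM decoder layers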
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num)
+ deepstack_feature = self.deepstack_merger_list[deepstack_merge_idx](hidden_states)
+ deepstack_feature_lists.append(deepstack_feature)
+
+ hidden_states = self.merger(hidden_states)
+
+ return hidden_states, deepstack_feature_lists
+
+
+class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+ """ModelForCausalLM."""
+
+ packed_modules_mapping = {
+ 'qkv_proj': [
+ 'q_proj',
+ 'k_proj',
+ 'v_proj',
+ ],
+ 'gate_up_proj': [
+ 'gate_proj',
+ 'up_proj',
+ ],
+ }
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+
+ # build preprocessor
+ self.input_processor = Qwen3VLInputProcessor(self.config)
+
+ # build vision model
+ self.visual = Qwen3VLVisionModel(
+ config.vision_config,
+ dtype=dtype,
+ device=device,
+ )
+
+ # build text model
+ self.language_model = Qwen3VLTextModel(config.text_config, dtype=dtype, device=device)
+
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.text_config.hidden_size,
+ config.text_config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ inputs_embeds: torch.Tensor = None,
+ mrope_position_ids: torch.Tensor = None,
+ pixel_values: torch.Tensor = None,
+ vis_cu_seqlens: torch.Tensor = None,
+ vis_pos_emb: torch.Tensor = None,
+ image_mask: torch.Tensor = None,
+ pos_embeds: torch.Tensor = None,
+ grid_thw: torch.Tensor = None,
+ **kwargs,
+ ):
+ """Model forward, return logits."""
+
+ visual_pos_masks = None
+ deepstack_visual_embeds = None
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ dtype = inputs_embeds.dtype
+ pixel_values = pixel_values.to(dtype)
+ vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))
+
+ # get image embeds and deepstack visual embeds
+ image_embeds, deepstack_visual_embeds = self.visual(pixel_values,
+ cu_seqlens=vis_cu_seqlens,
+ rotary_pos_emb=vis_pos_emb,
+ pos_embeds=pos_embeds)
+
+ # split image embeds per sample
+ split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)
+
+ # mask and scatter to create final input embeddings
+ expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
+ inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)
+
+ visual_pos_masks = expanded_image_mask
+
+ hidden_states = self.language_model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ mrope_position_ids=mrope_position_ids,
+ # args for deepstack
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ )
+ return hidden_states
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """Compute logits of the model output."""
+ return self.lm_head(hidden_states)
+
+ def update_weights(self):
+ """Update weights."""
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.language_model.embed_tokens.weight
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.language_model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """Prepare input."""
+
+ # get input_ids, position_ids and attention metadatas
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ pixel_values = None
+ vis_cu_seqlens = None
+ vis_pos_emb = None
+ image_mask = None
+ grid_thw = None
+ pos_embeds = None
+ if context.input_multimodals is not None:
+ image_data = [input_mm.get('image', []) for input_mm in context.input_multimodals]
+ if len(image_data) > 0:
+ # flatten batch
+ image_data = [data for im_data in image_data for data in im_data]
+ pixel_values = torch.cat([data.data for data in image_data])
+ image_token_id = image_data[0].meta['image_token_id']
+ image_mask = input_ids == image_token_id
+ grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu()
+ vis_pos_emb = self.visual.rot_pos_emb(grid_thw)
+ pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw)
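+                # one vision-attention segment per frame: h * w patch tokens repeated
+                # t times, accumulated into cumulative sequence lengths for the ViT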
+                vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+                                                         grid_thw[:, 0]).to(pixel_values.device)
+                vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32)
+                vis_pos_emb = vis_pos_emb.repeat(1, 2)
+                vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin())
+
+ mrope_position_ids = getattr(context, 'mrope_position_ids', None)
+
+ # process vision embeddings
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)
+
+ # inputs of forward
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ mrope_position_ids=mrope_position_ids,
+ pixel_values=pixel_values,
+ vis_cu_seqlens=vis_cu_seqlens,
+ vis_pos_emb=vis_pos_emb,
+ image_mask=image_mask,
+ grid_thw=grid_thw,
+ pos_embeds=pos_embeds,
+ )
+
+ def rename_weight(self, name: str) -> str:
+ """Rename weight."""
+ if name.startswith('model.language_model.'):
+ return 'language_model.' + name[len('model.language_model.'):]
+ elif name.startswith('model.visual.'):
+ return 'visual.' + name[len('model.visual.'):]
+ elif name.startswith('model.'):
+ return name[len('model.'):]
+ return name
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Load weights."""
+ # modify from vllm
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ('.gate_up_proj', '.gate_proj', 0),
+ ('.gate_up_proj', '.up_proj', 1),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if 'rotary_emb.inv_freq' in name:
+ continue
+ if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+ continue
+ if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+ continue
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ if '.qkv.' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """Make cudagraph buffers from forward inputs."""
+ max_tokens = graph_meta.max_tokens
+
+ input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
+ mrope_position_ids = kwargs.get('mrope_position_ids', None)
+ if mrope_position_ids is not None:
+ input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens)
+
+ return input_buffers
+
+ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs):
+ """Fill cudagraph buffers from forward inputs."""
+
+ new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs)
+
+ input_ids = kwargs.get('input_ids')
+ num_tokens = input_ids.size(-1)
+ new_batch_size = graph_meta.max_batchs
+
+ is_decoding = graph_meta.is_decoding
+ input_buffers = graph_meta.input_buffers
+ mrope_position_ids = kwargs.get('mrope_position_ids', None)
+ if mrope_position_ids is not None:
+ input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids
+ if is_decoding:
+ new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size]
+ else:
+ new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids']
+
+ return new_inputs
+
+ def _get_model_metas(self, context: StepContext):
+ """Get model metas."""
+ model_metas = context.model_metas
+ if model_metas is None:
+ batch_size = context.q_seqlens.numel()
+ return [dict(mrope_delta=0)] * batch_size
+ return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas]
+
+ def _update_model_meta_decoding(self, context: StepContext):
+ """Update model meta for decoding."""
+ model_metas = self._get_model_metas(context)
+ position_ids = context.position_ids
+
+ mrope_deltas = [meta['mrope_delta'] for meta in model_metas]
+ mrope_deltas = position_ids.new_tensor(mrope_deltas)
+ mrope_position_ids = position_ids + mrope_deltas[None]
+ mrope_position_ids = mrope_position_ids.expand(3, -1)
+
+ context.mrope_position_ids = mrope_position_ids
+ return model_metas
+
+ def _get_multimodal_pos_ids(self, grid_thw: list, device: torch.device):
+ """Get mrope ids."""
+ t, h, w = grid_thw
+ h //= 2
+ w //= 2
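+        # decompose the flat indices 0 .. t*h*w-1 into (t, h, w) coordinates via integer
+        # division by the per-axis stride followed by a modulo with the axis size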
+ stride = torch.tensor([h * w, w, 1], device=device)[:, None]
+ size = torch.tensor([t, h, w], device=device)[:, None]
+ pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1)
+ pos_ids = pos_ids // stride % size
+ return pos_ids
+
+ def _update_model_meta_prefilling(self, context: StepContext):
+ """Update model meta for prefilling."""
+ model_metas = self._get_model_metas(context)
+ input_multimodals = context.input_multimodals
+ if input_multimodals is None:
+ input_multimodals = [None] * len(model_metas)
+ position_ids = context.position_ids
+ batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist())
+ mrope_position_ids = []
+ new_model_metas = []
+ for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals):
+ images = []
+ if input_mm is not None:
+ images = input_mm.get('image', [])
+ if model_meta is None or 'mrope_delta' not in model_meta:
+ mrope_delta = 0
+ else:
+ mrope_delta = model_meta['mrope_delta']
+
+ pos_start = pos_ids[0].item()
+ mrope_pos_ids = pos_ids + mrope_delta
+ mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone()
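+            # each image occupies (end - start) tokens but only advances the 3D position
+            # counter by max(h, w); the surplus is subtracted from the positions of the
+            # following text and accumulated into mrope_delta for the decoding phase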
+ for img in images:
+ grid_thw = img.meta['grid_thw'][0].tolist()
+ _, h, w = grid_thw
+ h //= 2
+ w //= 2
+ num_pad = img.end - img.start - max(h, w)
+ mrope_delta -= num_pad
+ fill_start = img.start - pos_start
+ fill_end = img.end - pos_start
+ img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device)
+ img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1]
+ mrope_pos_ids[:, fill_end:] -= num_pad
+ mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids
+
+ mrope_position_ids.append(mrope_pos_ids)
+ new_model_metas.append(dict(mrope_delta=mrope_delta))
+
+ mrope_position_ids = torch.cat(mrope_position_ids, dim=1)
+ context.mrope_position_ids = mrope_position_ids
+
+ return new_model_metas
+
+ def update_model_metas(self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None):
+ """Update model meta."""
+ if context.is_decoding:
+ return self._update_model_meta_decoding(context)
+ else:
+ return self._update_model_meta_prefilling(context)
+
+ def get_input_processor(self) -> BaseModelInputProcessor:
+ """Get input processor."""
+ return self.input_processor
+
+
+InputMultiModalType = List[Dict[str, Any]]
diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py
new file mode 100644
index 0000000000..1dc7e32de9
--- /dev/null
+++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py
@@ -0,0 +1,234 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from lmdeploy.pytorch.model_inputs import StepContextManager
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .qwen3_moe import Qwen3MoeModel
+from .qwen3_vl import Qwen3VLForConditionalGeneration
+from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding
+
+
+class Qwen3VLMoeTextModel(Qwen3MoeModel):
+ """Text part of Qwen3VL.
+
+ not a pure text-only model, as DeepStack integrates visual features into the early hidden states.
+ """
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__(config=config, dtype=dtype, device=device)
+
+ # build rotary embedding
+ # TODO: zhouxinyu, add triton kernel for interleaved mrope
+ self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config, device=device)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ mrope_position_ids: torch.LongTensor = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
+ ):
+ """visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`,
+ *optional*):
+
+ The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack
+ visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the
+ different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack (
+ https://arxiv.org/abs/2406.04)
+ """
+
+ # token embedding
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ # rotary embedding
+ if mrope_position_ids is None:
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
+ else:
+ mrope_position_ids = mrope_position_ids.unsqueeze(1)
+ cos, sin = self.rotary_emb(hidden_states, mrope_position_ids)
+
+ cos, sin = cos[0], sin[0]
+ rotary_pos_emb = (cos, sin)
+
+ # decoding
+ residual = None
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx]
+ hidden_states, residual = decoder_layer(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ residual=residual,
+ attn_metadata=attn_metadata,
+ )
+
+ # add visual features to the hidden states of first several layers
+            if deepstack_visual_embeds is not None and idx < len(deepstack_visual_embeds):
+ hidden_states = hidden_states + residual
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[idx],
+ )
+ residual = None
+
+ # norm
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
+ visual_embeds: torch.Tensor):
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local = torch.zeros_like(hidden_states)
+ local.masked_scatter_(visual_pos_masks, visual_embeds)
+ hidden_states += local
+ return hidden_states
+
+
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+ """ModelForCausalLM."""
+
+ packed_modules_mapping = {
+ 'qkv_proj': [
+ 'q_proj',
+ 'k_proj',
+ 'v_proj',
+ ],
+ 'gate_up_proj': [
+ 'gate_proj',
+ 'up_proj',
+ ],
+ }
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device)
+
+ self.language_model = Qwen3VLMoeTextModel(config.text_config, dtype=dtype, device=device)
+
+ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
+ expert_params_mapping: List):
+ """Load weight experts."""
+
+ for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+            load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
+            break
+        else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ # modify from vllm qwen3vlmoe fused expert loading
+ def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
+ fused_expert_params_mapping: List):
+ """Load weight of fused expert weights."""
+ num_experts = self.config.text_config.num_experts
+
+ for (param_name, weight_name) in fused_expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+
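+            # the fused checkpoint stores all experts in one tensor; transpose it to
+            # (num_experts, out_features, in_features), split the gate/up halves along
+            # the output dim when present, and load every expert shard separately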
+ loaded_weight = loaded_weight.transpose(-1, -2) # no bias
+ if 'gate_up' in name:
+ loaded_weight = loaded_weight.chunk(2, dim=-2)
+ w1 = loaded_weight[0]
+ w3 = loaded_weight[1]
+ for expert_id in range(num_experts):
+ load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate')
+ load_weight(param, w3[expert_id], expert_id=expert_id, shard_id='up')
+ elif 'down' in name:
+ w2 = loaded_weight
+ for expert_id in range(num_experts):
+ load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down')
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Load weights."""
+ # modify from vllm
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ('.gate_up_proj', '.gate_proj', 0),
+ ('.gate_up_proj', '.up_proj', 1),
+ ]
+
+ # expert mapping
+ num_experts = self.config.text_config.num_experts
+ expert_params_mapping = []
+ for exp_id in range(num_experts):
+ # (param_name, weight_name, expert_id, shard_id)
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
+ expert_params_mapping += [gate_param, up_param, down_param]
+
+ # fused expert mapping
+ fused_expert_params_mapping = [
+ # (param_name, weight_name)
+ ('.experts.gate_up.weight', '.experts.gate_up_proj'),
+ ('.experts.down.weight', '.experts.down_proj'),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if 'rotary_emb.inv_freq' in name:
+ continue
+ if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+ continue
+ if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+ continue
+ name = name.replace('.block_sparse_moe.', '.mlp.')
+ if '.experts' in name:
+ is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name)
+ if is_fused_expert:
+ self._load_weight_fused_experts(name,
+ loaded_weight,
+ params_dict,
+ fused_expert_params_mapping=fused_expert_params_mapping)
+ else:
+ self._load_weight_experts(name,
+ loaded_weight,
+ params_dict,
+ expert_params_mapping=expert_params_mapping)
+ else:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ if '.qkv.' in name:
+ param = params_dict[name]
+ q, k, v = param.weight_spliter(loaded_weight)
+ load_weight(param, q, shard_id='q')
+ load_weight(param, k, shard_id='k')
+ load_weight(param, v, shard_id='v')
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py
index 64f5aa12c4..995f63c9f4 100644
--- a/lmdeploy/vl/model/builder.py
+++ b/lmdeploy/vl/model/builder.py
@@ -28,6 +28,7 @@
from .phi3_vision import Phi3VisionModel # noqa F401
from .qwen import QwenVisionModel # noqa F401
from .qwen2 import Qwen2VLModel # noqa F401
+from .qwen3 import Qwen3VLModel # noqa F401
from .xcomposer2 import Xcomposer2VisionModel # noqa F401
from .yi import YiVisionModel # noqa F401
diff --git a/lmdeploy/vl/model/qwen3.py b/lmdeploy/vl/model/qwen3.py
new file mode 100644
index 0000000000..40f2bf485c
--- /dev/null
+++ b/lmdeploy/vl/model/qwen3.py
@@ -0,0 +1,128 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Tuple
+
+import torch
+
+from lmdeploy.vl.model.base import VISION_MODELS, VisonModel
+
+
+def check_transformers():
+ try:
+ from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration # noqa: F401
+ except ImportError:
+        raise ImportError('please install the latest transformers with '
+                          '`pip install git+https://github.com/huggingface/transformers.git`')
+
+
+@VISION_MODELS.register_module()
+class Qwen3VLModel(VisonModel):
+ """Qwen3VL model."""
+
+ _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration']
+
+ def build_preprocessor(self):
+ check_transformers()
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(self.model_path)
+ tokenizer = self.processor.tokenizer
+ image_token = self.processor.image_token
+ self.image_token_id = tokenizer.encode(image_token)[-1]
+
+ def preprocess(self, messages: List[Dict]) -> List[Dict]:
+ """Refer to `super().preprocess()` for spec."""
+ images = self.collect_images(messages)
+ optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'}
+ outputs = []
+ for image, params in images:
+ image = image.convert('RGB')
+
+ item = dict(type='image', image=image)
+ item.update({key: params[key] for key in params.keys() if key in optional_keys})
+ result = self.processor.image_processor(images=image, videos=None, return_tensors='pt')
+ merge_length = self.processor.image_processor.merge_size**2
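+            # every merge_size x merge_size patch block collapses into a single image token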
+ image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length
+ result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id))
+ outputs.append(result)
+ messages.append(dict(role='preprocess', content=outputs))
+ return messages
+
+ def build_model(self):
+ # TODO: implement for turbomind
+ pass
+
+ @torch.no_grad()
+ def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]:
+ """Extract image feature. ONLY implement it when the backend is
+ turbomind engine.
+
+ Args:
+ messages(List[Dict]): the outputs of `preprocess`
+ max_batch_size(int): the max batch size when forwarding vision
+ model
+ Return:
+ the message list with forwarding results included
+ """
+ # TODO: implement for turbomind
+ pass
+
+ @staticmethod
+ def proc_messages(messages, chat_template, sequence_start):
+ """Apply chat template to get the prompt."""
+ prompt_messages = []
+        IMAGE_TOKEN = '<IMAGE_TOKEN>'
+ for message in messages:
+ if isinstance(message['content'], str):
+ prompt_messages.append(message)
+ continue
+ elif message['role'] in ['images', 'preprocess', 'forward']:
+ continue
+ n_images = len([1 for x in message['content'] if x['type'] == 'image'])
+ content = [item['text'] for item in message['content'] if item['type'] == 'text']
+ prompt = content[0]
+ if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt:
+ prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>')
+ else:
+                # Qwen2-VL-2B-Instruct concatenates the image and the user
+                # prompt according to their order in the content list. We
+                # insert the image token before the user prompt by default;
+                # users can place the image token themselves if they want the
+                # same decorated prompt as Qwen2-VL.
+ prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \
+ n_images + prompt
+ prompt_messages.append(dict(role=message['role'], content=prompt))
+ prompt = chat_template.messages2prompt(prompt_messages, sequence_start)
+ return prompt, IMAGE_TOKEN
+
+ @staticmethod
+ def get_mrope_info(seq_len: int,
+ grid_thws: List[Tuple[int, int, int]] = None,
+ ranges: List[Tuple[int, int]] = None):
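+        # Build the 3D (t, h, w) rope position ids: text spans share one monotonically
+        # increasing index on all three axes, each image span enumerates its grid, and
+        # the returned delta is the position offset to apply during decoding.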
+ mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)]
+ st_idx = ranges[0][0]
+ for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)):
+ llm_grid_t, llm_grid_h, llm_grid_w = grid_thw
+ llm_grid_h //= 2
+ llm_grid_w //= 2
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx)
+ st_idx += max(llm_grid_h, llm_grid_w)
+ if i < len(ranges) - 1:
+ text_len = ranges[i + 1][0] - ranges[i][1]
+ else:
+ text_len = seq_len - embedding_range[1]
+ mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx)
+ st_idx += text_len
+ mrope_position_ids = torch.cat(mrope_position_ids, dim=-1)
+ mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long)
+ return mrope_position_ids, mrope_position_delta
+
+ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
+ """Return to the information needed by pytorch engine."""
+ prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start)
+ return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start)
+
+ def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs):
+ # TODO: implement for turbomind
+ pass