Skip to content

Commit b7e951a

Browse files
authored
Feat: Support audio in Phi4-mm model (#8048)
1 parent d918ab7 commit b7e951a

File tree

11 files changed

+3332
-53
lines changed

11 files changed

+3332
-53
lines changed

docs/supported_models/multimodal_language_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ in the GitHub search bar.
3737
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3's larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
3838
| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is a multimodal model that can understand and generate text from images. |
3939
| **Mistral-Small-3.1-24B** | `mistralai/Mistral-Small-3.1-24B-Instruct-2503` | `mistral` | Mistral 3.1 is a multimodal model that can generate text from text or images input. It also supports tool calling and structured output. |
40-
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. Currently, it supports only text and vision modalities in SGLang. |
40+
| **Phi-4-multimodal-instruct** | `microsoft/Phi-4-multimodal-instruct` | `phi-4-mm` | Phi-4-multimodal-instruct is the multimodal variant of the Phi-4-mini model, enhanced with LoRA for improved multimodal capabilities. It supports text, vision and audio modalities in SGLang. |
4141
| **MiMo-VL** (7B) | `XiaomiMiMo/MiMo-VL-7B-RL` | `mimo-vl` | Xiaomi's compact yet powerful vision-language model featuring a native resolution ViT encoder for fine-grained visual details, an MLP projector for cross-modal alignment, and the MiMo-7B language model optimized for complex reasoning tasks. |

python/sglang/srt/conversation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,7 @@ def generate_chat_conv(
729729
sep="<|end|>",
730730
stop_str="<|end|>",
731731
image_token="<|endoftext10|>",
732+
audio_token="<|endoftext11|>",
732733
)
733734
)
734735

python/sglang/srt/managers/schedule_batch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,10 @@ class MultimodalDataItem:
239239
# For gemma3n
240240
input_features_mask: Optional[torch.Tensor] = None
241241

242+
# For phi4-mm
243+
image_attention_mask: Optional[torch.Tensor] = None
244+
audio_attention_mask: Optional[torch.Tensor] = None
245+
242246
@staticmethod
243247
def is_empty_list(l):
244248
if l is None:

python/sglang/srt/models/phi4mm.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from sglang.srt.model_loader.weight_utils import default_weight_loader
4141
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
4242
from sglang.srt.models.llama import LlamaForCausalLM
43+
from sglang.srt.models.phi4mm_audio import AudioEmbedding
4344

4445
logger = logging.getLogger(__name__)
4546

@@ -420,16 +421,49 @@ def __init__(
420421
model_dir=config._name_or_path,
421422
)
422423

424+
if isinstance(config.embd_layer["audio_embd_layer"], dict):
425+
embedding_config = {
426+
"embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"],
427+
**config.embd_layer["audio_embd_layer"],
428+
}
429+
else:
430+
embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}
431+
432+
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
433+
423434
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
424435
dtype = next(self.vision_encoder.parameters()).dtype
425436
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
426-
image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
437+
image_attention_mask = torch.cat(
438+
[item.image_attention_mask for item in items], dim=0
439+
)
427440
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
428441
image_embeds = self.vision_encoder(
429442
pixel_values, image_sizes, image_attention_mask
430443
)
431444
return torch.cat(image_embeds).type(dtype)
432445

446+
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
447+
# The first dim is the batch dim (e.g. multiple examples) and the second dim is the multi-audio dim
448+
# (e.g. multiple audios in the same example)
449+
embed_tokens_extend_param = next(self.embed_tokens_extend.parameters())
450+
device = embed_tokens_extend_param.device
451+
dtype = embed_tokens_extend_param.dtype
452+
audio_embeds = [
453+
self.embed_tokens_extend(
454+
# item.feature: (num_audios_in_a_sequence, T, D)
455+
# item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None
456+
audio_features=item.feature.to(device).type(dtype),
457+
audio_attention_mask=(
458+
item.audio_attention_mask.to(device)
459+
if item.audio_attention_mask is not None
460+
else None
461+
),
462+
)
463+
for item in items
464+
]
465+
return torch.cat(audio_embeds).type(dtype)
466+
433467
def forward(
434468
self,
435469
input_ids: torch.Tensor,
@@ -443,6 +477,7 @@ def forward(
443477
language_model=self.language_model,
444478
data_embedding_funcs={
445479
Modality.IMAGE: self.get_image_feature,
480+
Modality.AUDIO: self.get_audio_feature,
446481
},
447482
positions=positions,
448483
)
@@ -464,6 +499,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
464499
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
465500
]
466501
prefix_mapping = {
502+
"model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
503+
"model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
504+
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
467505
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
468506
"model.": "language_model.model.",
469507
}
@@ -472,7 +510,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
472510
"img_processor.encoder.layers.26",
473511
"img_processor.head",
474512
"img_processor.post_layernorm",
475-
"audio",
476513
]
477514

478515
def _should_skip(name: str) -> bool:

0 commit comments

Comments
 (0)