3 changes: 3 additions & 0 deletions .gitignore
@@ -228,5 +228,8 @@ compile_commands.json
 
 1
 
 # Autoenv
 .env.leave
+
+# Rust lib
+Cargo.lock
3 changes: 2 additions & 1 deletion python/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
 
 [project.optional-dependencies]
 runtime_common = [
+    "blobfile==3.0.0",
     "compressed-tensors",
     "datasets",
     "fastapi",
@@ -38,12 +39,12 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
+    "scipy",
     "torchao==0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.19",
-    "blobfile==3.0.0"
 ]
 
 srt = [
1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
@@ -549,6 +549,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
+    "Phi4MMForCausalLM",
 ]
 
 
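Note: this list is how the model is recognized as a multimodal generation architecture. A hedged sketch of the kind of check such a registry typically feeds; the helper name and exact logic below are illustrative, not sglang's actual code:

MULTIMODAL_GENERATION_ARCHS = [
    "Qwen2_5_VLForConditionalGeneration",
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "Phi4MMForCausalLM",  # newly registered by this PR
]

def is_multimodal_generation_model(model_architectures: list[str]) -> bool:
    # A model qualifies if any architecture declared in its HF config is registered.
    return any(arch in MULTIMODAL_GENERATION_ARCHS for arch in model_architectures)

print(is_multimodal_generation_model(["Phi4MMForCausalLM"]))  # True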
21 changes: 21 additions & 0 deletions python/sglang/srt/conversation.py
@@ -661,6 +661,21 @@ def generate_chat_conv(
     )
 )
 
+# TODO (lifuhuang): Refactor BaseMultimodalProcessor to support the default image token "<|image_{index}|>" in the future.
+register_conv_template(
+    Conversation(
+        name="phi-4-mm",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|system|>{system_message}<|end|>",
+        roles=("<|user|>", "<|assistant|>"),
+        sep_style=SeparatorStyle.NO_COLON_SINGLE,
+        sep="<|end|>",
+        stop_str="<|end|>",
+        image_token="<|endoftext10|>",
+        # image_token="<|image_{index}|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
@@ -945,3 +960,9 @@ def match_openbmb_minicpm(model_path: str):
 def match_moonshot_kimivl(model_path: str):
     if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
         return "kimi-vl"
+
+
+@register_conv_template_matching_function
+def match_phi_4_mm(model_path: str):
+    if "phi-4-multimodal" in model_path.lower():
+        return "phi-4-mm"
87 changes: 87 additions & 0 deletions python/sglang/srt/managers/multimodal_processors/phi4mm.py
@@ -0,0 +1,87 @@
+import logging
+from typing import List, Union
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.phi4mmvllm import Phi4MMForCausalLM
+
+logger = logging.getLogger(__name__)
+
+_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>"
+_IMAGE_SPECIAL_TOKEN_ID = 200010
+
+
+class Phi4MMImageProcessor(BaseMultimodalProcessor):
+    models = [Phi4MMForCausalLM]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.multimodal_tokens = MultimodalSpecialTokens(
+            image_token=_IMAGE_SPECIAL_TOKEN,
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        audio_data = request_obj.audio_data
+
+        if not image_data and not audio_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        if audio_data:
+            logger.warning(
+                "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
+            )
+            audio_data = []
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            image_data=image_data,
+            multimodal_tokens=self.multimodal_tokens,
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            audios=base_output.audios,
+        )
+
+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=_IMAGE_SPECIAL_TOKEN_ID,
+        )
+
+        items = [
+            MultimodalDataItem(
+                pixel_values=res["input_image_embeds"],
+                image_sizes=res["image_sizes"],
+                image_emb_mask=res["image_attention_mask"],
+                image_offsets=image_offsets,
+                modality=Modality.IMAGE,
+            )
+        ]
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": _IMAGE_SPECIAL_TOKEN_ID,
+        }
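Note: get_mm_items_offset (inherited from BaseMultimodalProcessor) locates the image placeholder tokens inside the tokenized prompt. A self-contained sketch of the underlying idea, assuming offsets are (start, end) index pairs for each consecutive run of the image token id; the real helper's return shape may differ:

import torch

_IMAGE_SPECIAL_TOKEN_ID = 200010

def image_token_runs(input_ids: torch.Tensor, token_id: int):
    """Return inclusive (start, end) index pairs for each consecutive run of token_id."""
    positions = (input_ids == token_id).nonzero(as_tuple=True)[0].tolist()
    runs = []
    for pos in positions:
        if runs and pos == runs[-1][1] + 1:
            runs[-1] = (runs[-1][0], pos)  # extend the current run
        else:
            runs.append((pos, pos))  # start a new run
    return runs

ids = torch.tensor([101, 5, 200010, 200010, 200010, 7, 102])
print(image_token_runs(ids, _IMAGE_SPECIAL_TOKEN_ID))  # [(2, 4)]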
27 changes: 22 additions & 5 deletions python/sglang/srt/models/minicpmv.py
@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
+
 from functools import partial
 from typing import (
     Any,
@@ -386,6 +387,7 @@ def __init__(
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        require_post_norm: bool = True,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -398,20 +400,35 @@ def __init__(
             quant_config=quant_config,
             prefix=add_prefix("encoder", prefix),
         )
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.post_layernorm = (
+            nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+            if require_post_norm
+            else nn.Identity()
+        )
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings
 
-    def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
-        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # shape: (batch_size,)
+    def compute_cu_seqlens(
+        self,
+        tgt_sizes: Optional[torch.Tensor] = None,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
+    ) -> torch.Tensor:
+        # shape: (batch_size,)
+        if tgt_sizes is not None:
+            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        else:
+            patch_len = patch_attention_mask[:, :, 0].sum(
+                dim=1
+            ) * patch_attention_mask[:, 0, :].sum(dim=1)
 
         cu_seqlens = torch.cat(
             [
                 torch.tensor([0], device=patch_len.device, dtype=torch.int32),
                 torch.cumsum(patch_len, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(tgt_sizes.device)
+        ).to(patch_len.device)
         return cu_seqlens
 
     def forward(
@@ -425,7 +442,7 @@ def forward(
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
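Note: to make the new mask-based branch of compute_cu_seqlens concrete: each image's patch count is the number of valid rows times the number of valid columns in its boolean patch mask, and cu_seqlens is the zero-prefixed cumulative sum of those counts. A standalone sketch with hypothetical shapes, mirroring but not importing the method above:

import torch

# Two images inside a padded 2x4 patch grid: the first uses 2x3 patches, the second 1x2.
mask = torch.zeros(2, 2, 4, dtype=torch.bool)
mask[0, :2, :3] = True
mask[1, :1, :2] = True

# Same reduction as the else-branch above: valid rows * valid columns per image.
patch_len = mask[:, :, 0].sum(dim=1) * mask[:, 0, :].sum(dim=1)  # tensor([6, 2])

cu_seqlens = torch.cat(
    [
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(patch_len, dim=0, dtype=torch.int32),
    ]
)
print(cu_seqlens)  # tensor([0, 6, 8], dtype=torch.int32)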