diff --git a/mbridge/models/qwen3_5/base_bridge.py b/mbridge/models/qwen3_5/base_bridge.py
index c6ccd26..ca8cbee 100644
--- a/mbridge/models/qwen3_5/base_bridge.py
+++ b/mbridge/models/qwen3_5/base_bridge.py
@@ -837,6 +837,7 @@ def provider(
             image_token_id=self.hf_config.image_token_id,
             video_token_id=self.hf_config.video_token_id,
             vision_start_token_id=self.hf_config.vision_start_token_id,
+            vp_stage=vp_stage,
         )
 
         for callback in post_model_creation_callbacks:
diff --git a/mbridge/models/qwen3_5/model.py b/mbridge/models/qwen3_5/model.py
index 03556a7..b3253fa 100644
--- a/mbridge/models/qwen3_5/model.py
+++ b/mbridge/models/qwen3_5/model.py
@@ -1,3 +1,4 @@
+import inspect
 import logging
 from typing import Optional
 
@@ -71,6 +72,7 @@ def __init__(
         image_token_id: int = 151655,
         video_token_id: int = 151656,
         vision_start_token_id: int = 151652,
+        vp_stage: Optional[int] = None,
         rope_scaling: bool = False,
     ) -> None:
         super().__init__(config=language_transformer_config)
@@ -100,7 +102,7 @@ def __init__(
         self._hook_fp32_rotary_emb(self.vision_model)
         self._hook_vision_params_avg_grad_across_tp(self.vision_model)
 
-        self.language_model = Qwen3_5GPTModel(
+        language_model_kwargs = dict(
             config=language_transformer_config,
             transformer_layer_spec=language_transformer_layer_spec,
             vocab_size=language_vocab_size,
@@ -117,6 +119,10 @@ def __init__(
             mtp_block_spec=language_mtp_block_spec,
             scatter_embedding_sequence_parallel=False,
         )
+        if "vp_stage" in inspect.signature(GPTModel.__init__).parameters:
+            language_model_kwargs["vp_stage"] = vp_stage
+
+        self.language_model = Qwen3_5GPTModel(**language_model_kwargs)
 
     @staticmethod
     def _hook_fp32_rotary_emb(module: torch.nn.Module):
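
Note on the model.py change: rather than pinning a Megatron-Core version, the constructor feature-detects the new `vp_stage` keyword with `inspect.signature` and only forwards it when the installed `GPTModel.__init__` actually accepts it, so the bridge keeps working on releases that predate virtual-pipeline-stage support. Below is a minimal, self-contained sketch of that pattern; `LegacyModel`, `NewModel`, and `build_model` are hypothetical names for illustration, not part of mbridge or Megatron-Core.

import inspect
from typing import Optional


class LegacyModel:
    # Stands in for an older GPTModel that does not know about vp_stage.
    def __init__(self, config: dict) -> None:
        self.config = config


class NewModel:
    # Stands in for a newer GPTModel that accepts vp_stage.
    def __init__(self, config: dict, vp_stage: Optional[int] = None) -> None:
        self.config = config
        self.vp_stage = vp_stage


def build_model(model_cls, config: dict, vp_stage: Optional[int] = None):
    kwargs = dict(config=config)
    # Feature-detect the keyword instead of checking a version string.
    if "vp_stage" in inspect.signature(model_cls.__init__).parameters:
        kwargs["vp_stage"] = vp_stage
    return model_cls(**kwargs)


build_model(LegacyModel, {})             # works: vp_stage is silently dropped
build_model(NewModel, {}, vp_stage=1)    # works: vp_stage is forwarded

The diff inspects the base `GPTModel.__init__` while instantiating `Qwen3_5GPTModel`, presumably because the subclass forwards these keywords to the base constructor, making the base signature the authoritative test.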