
Commit 969bbc7

Authored by Abatom, Jumiar, Zyann7, and jeejeelee
[Model] Add MiMo-V2-Flash support (vllm-project#30836)
Signed-off-by: Abatom <[email protected]>
Signed-off-by: Jumiar <[email protected]>
Signed-off-by: Zyann7 <[email protected]>
Co-authored-by: Jumiar <[email protected]>
Co-authored-by: Zyann7 <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
1 parent 268a972 commit 969bbc7

File tree: 8 files changed (+789, -13 lines)


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

@@ -415,6 +415,7 @@ th {
 | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ |
 | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ |
 | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ |
+| `MiMoV2FlashForCausalLM` | MiMoV2Flash | `XiaomiMiMo/MiMo-V2-Flash`, etc. | | ✅︎ |
 | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ |
 | `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ |
 | `MiniMaxM2ForCausalLM` | MiniMax-M2 | `MiniMaxAI/MiniMax-M2`, etc. | | ✅︎ |

tests/models/registry.py

Lines changed: 3 additions & 0 deletions

@@ -459,6 +459,9 @@ def check_available_online(
     ),
     "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"),
     "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True),
+    "MiMoV2FlashForCausalLM": _HfExamplesInfo(
+        "XiaomiMiMo/MiMo-V2-Flash", trust_remote_code=True
+    ),
     "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
 }

vllm/config/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 from vllm.config.model import (
     ModelConfig,
     iter_architecture_defaults,
+    str_dtype_to_torch_dtype,
     try_match_architecture_defaults,
 )
 from vllm.config.multimodal import MultiModalConfig
@@ -72,6 +73,7 @@
     # From vllm.config.model
     "ModelConfig",
     "iter_architecture_defaults",
+    "str_dtype_to_torch_dtype",
     "try_match_architecture_defaults",
     # From vllm.config.multimodal
     "MultiModalConfig",

vllm/config/model.py

Lines changed: 5 additions & 0 deletions

@@ -1849,6 +1849,11 @@ def try_match_architecture_defaults(
     "bfloat16": torch.bfloat16,
 }

+
+def str_dtype_to_torch_dtype(type: str):
+    return _STR_DTYPE_TO_TORCH_DTYPE.get(type)
+
+
 # model_type -> reason
 _FLOAT16_NOT_SUPPORTED_MODELS = {
     "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",

vllm/model_executor/layers/linear.py

Lines changed: 49 additions & 13 deletions

@@ -277,6 +277,7 @@ def __init__(
         self.params_dtype = params_dtype
         self.quant_config = quant_config
         self.prefix = prefix
+        self.allow_fp8_block_shape_mismatch = False
         if quant_config is None:
             self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
         else:
@@ -475,6 +476,7 @@ def __init__(
             disable_tp=disable_tp,
         )

+        self._maybe_allow_fp8_block_shape_mismatch()
         self.gather_output = gather_output

         if output_sizes is None:
@@ -509,6 +511,33 @@ def __init__(
         self.register_parameter("bias", None)
         self.update_param_tp_status()

+    def _maybe_allow_fp8_block_shape_mismatch(self) -> None:
+        quant_config = getattr(self, "quant_config", None)
+        weight_block = getattr(quant_config, "weight_block_size", None)
+        if (
+            weight_block is None
+            or len(weight_block) < 1
+            or len(self.output_partition_sizes) <= 1
+        ):
+            return
+
+        try:
+            block_n = int(weight_block[0])
+        except (ValueError, TypeError):
+            return
+
+        if block_n <= 0:
+            return
+
+        if any(size % block_n != 0 for size in self.output_partition_sizes):
+            self.allow_fp8_block_shape_mismatch = True
+            logger.debug(
+                "Allowing FP8 block shape mismatch for %s (block_n=%d, partitions=%s)",
+                getattr(self, "prefix", "<unknown>"),
+                block_n,
+                self.output_partition_sizes,
+            )
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)

@@ -906,9 +935,11 @@ def __init__(
         *,
         return_bias: bool = True,
         disable_tp: bool = False,
+        v_head_size: int | None = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
+        self.v_head_size = v_head_size if v_head_size is not None else head_size
         self.total_num_heads = total_num_heads
         if total_num_kv_heads is None:
             total_num_kv_heads = total_num_heads
@@ -924,12 +955,14 @@ def __init__(
             self.num_kv_head_replicas = 1
         input_size = self.hidden_size
         output_size = (
-            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        )
+            self.num_heads * self.head_size
+            + self.num_kv_heads * self.head_size
+            + self.num_kv_heads * self.v_head_size
+        ) * tp_size
         self.output_sizes = [
             self.num_heads * self.head_size * tp_size,  # q_proj
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
         ]

         super().__init__(
@@ -950,15 +983,16 @@ def _get_shard_offset_mapping(self, loaded_shard_id: str):
             "q": 0,
             "k": self.num_heads * self.head_size,
             "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+            "total": (self.num_heads + self.num_kv_heads) * self.head_size
+            + self.num_kv_heads * self.v_head_size,
         }
         return shard_offset_mapping.get(loaded_shard_id)

     def _get_shard_size_mapping(self, loaded_shard_id: str):
         shard_size_mapping = {
             "q": self.num_heads * self.head_size,
             "k": self.num_kv_heads * self.head_size,
-            "v": self.num_kv_heads * self.head_size,
+            "v": self.num_kv_heads * self.v_head_size,
         }
         return shard_size_mapping.get(loaded_shard_id)

@@ -985,7 +1019,7 @@ def _load_fused_module_from_checkpoint(
             (
                 "v",
                 (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
+                self.total_num_kv_heads * self.v_head_size,
             ),
         ]

@@ -1110,7 +1144,7 @@ def weight_loader(
             (
                 "v",
                 (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
+                self.total_num_kv_heads * self.v_head_size,
             ),
         ]
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
@@ -1139,11 +1173,12 @@ def weight_loader(
             "v": (
                 (self.total_num_heads + self.total_num_kv_heads)
                 * self.head_size,
-                self.total_num_kv_heads * self.head_size,
+                self.total_num_kv_heads * self.v_head_size,
             ),
             "total": (
-                (self.total_num_heads + 2 * self.total_num_kv_heads)
-                * self.head_size,
+                (self.total_num_heads + self.total_num_kv_heads)
+                * self.head_size
+                + self.total_num_kv_heads * self.v_head_size,
                 0,
             ),
         }
@@ -1170,7 +1205,7 @@ def weight_loader(
                 shard_size = self.num_kv_heads * self.head_size
             elif loaded_shard_id == "v":
                 shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
-                shard_size = self.num_kv_heads * self.head_size
+                shard_size = self.num_kv_heads * self.v_head_size
             # Special case for Quantized Weights.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -1199,10 +1234,11 @@ def weight_loader(
             ),
             "v": (
                 (self.num_heads + self.num_kv_heads) * self.head_size,
-                self.num_kv_heads * self.head_size,
+                self.num_kv_heads * self.v_head_size,
             ),
             "total": (
-                (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+                (self.num_heads + self.num_kv_heads) * self.head_size
+                + self.num_kv_heads * self.v_head_size,
                 0,
             ),
         }
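
Note: the core of the `QKVParallelLinear` change is that the fused output size is no longer `(num_heads + 2 * num_kv_heads) * head_size`; the v_proj shard gets its own `v_head_size`. The sketch below is a hypothetical, standalone restatement of that arithmetic (the helper name and the head counts are illustrative and not taken from MiMo-V2-Flash).

```python
# Hypothetical helper mirroring the new QKVParallelLinear sizing in the diff above.
def qkv_output_sizes(
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    v_head_size: int | None = None,
    tp_size: int = 1,
) -> list[int]:
    # Default matches the diff: v_head_size falls back to head_size.
    v_head_size = v_head_size if v_head_size is not None else head_size
    return [
        num_heads * head_size * tp_size,       # q_proj shard
        num_kv_heads * head_size * tp_size,    # k_proj shard
        num_kv_heads * v_head_size * tp_size,  # v_proj shard
    ]

# With symmetric heads this reduces to the old (num_heads + 2 * num_kv_heads) * head_size.
assert sum(qkv_output_sizes(32, 8, 128)) == (32 + 2 * 8) * 128

# With a different value-head size, only the v_proj shard changes.
assert qkv_output_sizes(32, 8, 128, v_head_size=256) == [32 * 128, 8 * 128, 8 * 256]
```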

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 8 additions & 0 deletions

@@ -1252,6 +1252,14 @@ def validate_fp8_block_shape(
     """Validate block quantization shapes for tensor parallelism."""
     from vllm.distributed import get_tensor_model_parallel_world_size

+    if getattr(layer, "allow_fp8_block_shape_mismatch", False):
+        logger.debug(
+            "Skipping FP8 block shape validation for layer %s due to detected"
+            " mismatch allowance.",
+            getattr(layer, "prefix", "<unknown>"),
+        )
+        return
+
     tp_size = getattr(layer, "tp_size", get_tensor_model_parallel_world_size())
     block_n, block_k = block_size[0], block_size[1]
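Note: a hedged sketch of how the two FP8 changes fit together, outside the real vLLM call path. The `FakeLayer` class, its partition sizes, and `block_n=128` are made-up values; only the two conditions are taken from the diffs (the flag set in linear.py when a fused partition is not divisible by the block size, and the early return in `validate_fp8_block_shape`).

```python
class FakeLayer:
    """Stand-in for a fused linear layer (illustrative values only)."""
    prefix = "model.layers.0.mlp.gate_up_proj"  # hypothetical prefix
    output_partition_sizes = [2112, 2112]       # 2112 % 128 == 64, so not block-aligned
    allow_fp8_block_shape_mismatch = False


def maybe_allow_mismatch(layer: FakeLayer, block_n: int) -> None:
    # Same condition as _maybe_allow_fp8_block_shape_mismatch above:
    # any fused partition not divisible by the FP8 block size flips the flag.
    if any(size % block_n != 0 for size in layer.output_partition_sizes):
        layer.allow_fp8_block_shape_mismatch = True


def validate_block_shape(layer: FakeLayer, block_n: int) -> None:
    # Mimics the new early return in validate_fp8_block_shape: layers that
    # opted in skip the strict divisibility check that would otherwise fail.
    if getattr(layer, "allow_fp8_block_shape_mismatch", False):
        return
    assert all(size % block_n == 0 for size in layer.output_partition_sizes)


layer = FakeLayer()
maybe_allow_mismatch(layer, block_n=128)
validate_block_shape(layer, block_n=128)  # passes: validation is skipped
```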