diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 5a6b6e15454f..bb89f85bcc00 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -186,6 +186,28 @@ def _is_sm100(device: torch.device) -> bool: return torch.cuda.get_device_capability(device)[0] >= 10 +def _assert_sm100_scales_are_ue8m0(scale: torch.Tensor) -> None: + """On B200 (SM100) DeepGEMM only supports UE8M0 (power-of-two) scales; the float32 scales + that work on H100 (SM90) have no SM100 path. UE8M0 scales load as ``float8_e8m0fnu`` (the + loader normalizes even float32-container checkpoints like dsv4-flash-base), so a plain + ``float32`` scale here means a genuine non-UE8M0 checkpoint — fail loud rather than let + ``_coerce_sf_for_kernel`` silently round it and corrupt the output. + """ + if not _is_sm100(scale.device): + return # SM90 consumes float32 SFs directly (no UE8M0 round). + if scale.dtype != torch.float32: + return # already UE8M0 (`float8_e8m0fnu`) — kernel-ready as-is. + raise ValueError( + "DeepGEMM's Blackwell (SM100) experts kernel requires power-of-two (UE8M0) scale " + "factors, but this checkpoint's expert scales are plain float32 " + "(quantization_config.scale_fmt='float'). Rounding them to UE8M0 would scale the " + "dequantized expert weights incorrectly and silently corrupt the output. Use a " + "checkpoint quantized with scale_fmt='ue8m0', or an experts implementation that " + "consumes float32 block scales directly, e.g. " + "`model.set_experts_implementation('grouped_mm')`." + ) + + _DEEPGEMM_VISITED_DEVICES: set[int] = set() @@ -556,6 +578,7 @@ def deepgemm_fp8_fp4_experts_forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: _assert_single_device(hidden_states.device, context="experts") + _assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv) if self.activation_scheme == "static": raise NotImplementedError("DeepGEMM experts dispatch does not support static activation quantization.") @@ -715,6 +738,8 @@ def deepgemm_fp8_fp4_megamoe_experts_forward( `transform_weights_for_mega_moe((gate_up, gate_up_sf), (down, down_sf))`. - `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped. """ + _assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv) + if self.gate_up_proj.dtype != torch.int8: raise RuntimeError( f"DeepGEMM Mega MoE requires FP4-packed expert weights (dtype=`int8`), got " diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 8dd8c97bc207..2ace715254c2 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -52,7 +52,11 @@ @functools.cache def _get_ue8m0_dtype() -> torch.dtype: - """Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support.""" + """Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support. + + UE8M0 scales are always stored/consumed as this single dtype — the kernels (Triton + finegrained + DeepGEMM) read it natively, and supporting the same scales in mixed + container dtypes would be a mess — so fail loudly rather than fall back.""" if not hasattr(torch, "float8_e8m0fnu"): raise RuntimeError( "scale_fmt='ue8m0' requires torch.float8_e8m0fnu, which is only available in " @@ -294,7 +298,10 @@ def __init__( sf_dtype = _get_ue8m0_dtype() if scale_fmt == "ue8m0" else torch.float32 scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0] scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1] - self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype)) + self.weight_scale_inv = nn.Parameter( + torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype), + requires_grad=sf_dtype.is_floating_point, + ) if self.activation_scheme == "static": self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32)) @@ -871,8 +878,8 @@ def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor quantized = quantized.reshape(original_shape) inv_scales = (1.0 / scales).to(torch.float32) # DeepSeek V4-style storage (`scale_fmt="ue8m0"`): round inv_scales to UE8M0-representable - # values (powers of 2) and cast to `float8_e8m0fnu` byte storage so the on-disk dtype - # matches the parameter allocation in `FP8Linear`/`FP8Experts`. + # values (powers of 2) and cast to the UE8M0 byte storage so the on-disk dtype matches the + # parameter allocation in `FP8Linear`/`FP8Experts`. if self.hf_quantizer.quantization_config.scale_fmt == "ue8m0": inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny)))) inv_scales = inv_scales.to(_get_ue8m0_dtype()) @@ -1045,29 +1052,3 @@ def reverse_op(self) -> ConversionOps: # checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``) # whether the in-memory state stayed quantized or was dequantized for compute. return Fp8Quantize(self.hf_quantizer) - - -class Fp8DecodeScale(ConversionOps): - """Decode MXFP8 ``ue8m0`` per-block scales (stored as ``uint8`` exponents) into the - float32 multiplicative scales the FP8 compute path expects. - - Native MXFP8 loading (``dequantize=False``) keeps weights in ``float8_e4m3fn`` and only - needs the sibling ``*.weight_scale_inv`` tensors turned from raw E8M0 bytes into real - scales (``2 ** (byte - 127)``). Prepended to each weight converter, this op runs before - any merge/concat collapses the per-expert structure: it rewrites only the ``uint8`` scale - entries and passes weights (and already-float scales) through untouched. - """ - - def __init__(self, hf_quantizer): - self.hf_quantizer = hf_quantizer - - @staticmethod - def _decode(tensor: torch.Tensor) -> torch.Tensor: - # E8M0 stores one exponent byte per block; the real scale is ``2 ** (byte - 127)``. - return (tensor.to(torch.float32) - 127.0).exp2() if tensor.dtype == torch.uint8 else tensor - - def convert(self, input_dict: dict[str, list[torch.Tensor] | torch.Tensor], **kwargs): - return { - key: [self._decode(t) for t in value] if isinstance(value, list) else self._decode(value) - for key, value in input_dict.items() - } diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 544cb3698cb5..d6ff0d503270 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1458,6 +1458,12 @@ def tp_plan(self) -> dict[str, str]: The full tp plan for the model's modules """ if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel: + if not self._ep_plan: + raise ValueError( + f"Expert parallelism was requested (`enable_expert_parallel=True`), but " + f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a " + f"`base_model_ep_plan` to its config, or disable expert parallelism." + ) return self._ep_plan return self._tp_plan diff --git a/src/transformers/models/afmoe/configuration_afmoe.py b/src/transformers/models/afmoe/configuration_afmoe.py index bc1919e40b1c..df3980e51917 100644 --- a/src/transformers/models/afmoe/configuration_afmoe.py +++ b/src/transformers/models/afmoe/configuration_afmoe.py @@ -61,6 +61,12 @@ class AfmoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.router": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 200192 hidden_size: int = 2048 diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py index 421119b33deb..99411a4b1ed2 100644 --- a/src/transformers/models/afmoe/modeling_afmoe.py +++ b/src/transformers/models/afmoe/modeling_afmoe.py @@ -239,7 +239,7 @@ def forward(self, hidden_states): hidden_states_flat = hidden_states.view(-1, hidden_dim) # Get routing decisions (returns flattened top-k) - router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias) + _, top_scores, selected_experts = self.router(hidden_states, self.expert_bias) # Process through shared experts shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim) diff --git a/src/transformers/models/afmoe/modular_afmoe.py b/src/transformers/models/afmoe/modular_afmoe.py index f3ff9f15b103..3a2d5ff42308 100644 --- a/src/transformers/models/afmoe/modular_afmoe.py +++ b/src/transformers/models/afmoe/modular_afmoe.py @@ -111,7 +111,7 @@ def forward(self, hidden_states): hidden_states_flat = hidden_states.view(-1, hidden_dim) # Get routing decisions (returns flattened top-k) - router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias) + _, top_scores, selected_experts = self.router(hidden_states, self.expert_bias) # Process through shared experts shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim) diff --git a/src/transformers/models/cohere2_moe/configuration_cohere2_moe.py b/src/transformers/models/cohere2_moe/configuration_cohere2_moe.py index 9c1abfc06cd7..32ef6fdf66ad 100644 --- a/src/transformers/models/cohere2_moe/configuration_cohere2_moe.py +++ b/src/transformers/models/cohere2_moe/configuration_cohere2_moe.py @@ -84,6 +84,12 @@ class Cohere2MoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 256000 hidden_size: int = 8192 diff --git a/src/transformers/models/cohere2_moe/modeling_cohere2_moe.py b/src/transformers/models/cohere2_moe/modeling_cohere2_moe.py index 5e2a37c43d1d..e0290b835948 100644 --- a/src/transformers/models/cohere2_moe/modeling_cohere2_moe.py +++ b/src/transformers/models/cohere2_moe/modeling_cohere2_moe.py @@ -142,6 +142,7 @@ class Cohere2MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.expert_selection_fn = config.expert_selection_fn self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size)) diff --git a/src/transformers/models/cohere2_moe/modular_cohere2_moe.py b/src/transformers/models/cohere2_moe/modular_cohere2_moe.py index 302def6284dd..b5d4c3b49f57 100644 --- a/src/transformers/models/cohere2_moe/modular_cohere2_moe.py +++ b/src/transformers/models/cohere2_moe/modular_cohere2_moe.py @@ -69,6 +69,7 @@ class Cohere2MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.expert_selection_fn = config.expert_selection_fn self.norm_topk_prob = config.norm_topk_prob self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size)) diff --git a/src/transformers/models/deepseek_ocr2/configuration_deepseek_ocr2.py b/src/transformers/models/deepseek_ocr2/configuration_deepseek_ocr2.py index 2d9c601c1f2a..1a1972089aa4 100644 --- a/src/transformers/models/deepseek_ocr2/configuration_deepseek_ocr2.py +++ b/src/transformers/models/deepseek_ocr2/configuration_deepseek_ocr2.py @@ -237,12 +237,23 @@ class DeepseekOcr2TextConfig(PreTrainedConfig): attention_dropout: float | None = 0.0 mlp_bias: bool = False head_dim: int | None = None + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + attribute_map = { + "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", + } n_group: int | None = None n_routed_experts: int = 64 n_shared_experts: int = 2 routed_scaling_factor: float = 1.0 topk_group: int | None = None topk_method: str | None = "greedy" + norm_topk_prob: bool | None = False num_experts_per_tok: int | None = None moe_intermediate_size: int = 1407 diff --git a/src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py b/src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py index 4f9b8f1374b6..71229ea637ad 100644 --- a/src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py +++ b/src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py @@ -1160,7 +1160,7 @@ class DeepseekOcr2TextExperts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.n_routed_experts + self.num_experts = config.num_experts self.hidden_dim = config.hidden_size self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) @@ -1194,48 +1194,55 @@ def forward( return final_hidden_states -class DeepseekOcr2TextMoe(nn.Module): +class DeepseekOcr2TextTopkRouter(nn.Module): def __init__(self, config: DeepseekOcr2TextConfig): super().__init__() - self.config = config - self.experts = DeepseekOcr2TextExperts(config) - self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor self.topk_method = config.topk_method self.num_group = config.n_group - self.top_k = config.num_experts_per_tok self.topk_group = config.topk_group - def route_tokens_to_experts(self, router_logits): - batch_size, seq_len, hidden_dim = router_logits.shape - router_logits = router_logits.view(-1, hidden_dim) - router_logits = router_logits.softmax(dim=-1, dtype=torch.float32) + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.softmax(dim=-1, dtype=torch.float32) if self.topk_method == "greedy": - topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) elif self.topk_method == "group_limited_greedy": - group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values + group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group) - .reshape(batch_size * seq_len, -1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) ) - tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0) - topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices - topk_weight = topk_weight * self.routed_scaling_factor - return topk_idx, topk_weight + +class DeepseekOcr2TextMoe(nn.Module): + def __init__(self, config: DeepseekOcr2TextConfig): + super().__init__() + self.config = config + self.experts = DeepseekOcr2TextExperts(config) + self.gate = DeepseekOcr2TextTopkRouter(config) + self.shared_experts = DeepseekOcr2TextMLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -1330,7 +1337,9 @@ class DeepseekOcr2TextPreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - if isinstance(module, DeepseekOcr2TextExperts): + if isinstance(module, DeepseekOcr2TextTopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, DeepseekOcr2TextExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_ocr2/modular_deepseek_ocr2.py b/src/transformers/models/deepseek_ocr2/modular_deepseek_ocr2.py index cddbd6f2151e..793fc0e139b9 100644 --- a/src/transformers/models/deepseek_ocr2/modular_deepseek_ocr2.py +++ b/src/transformers/models/deepseek_ocr2/modular_deepseek_ocr2.py @@ -604,7 +604,6 @@ class DeepseekOcr2TextConfig(DeepseekV2Config): # Remove unused attributes inherited from DeepseekV2Config first_k_dense_replace = AttributeError() kv_lora_rank = AttributeError() - norm_topk_prob = AttributeError() q_lora_rank = AttributeError() qk_nope_head_dim = AttributeError() qk_rope_head_dim = AttributeError() diff --git a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py index 1b8005f8efef..e8481b0284de 100644 --- a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py @@ -94,6 +94,16 @@ class DeepseekV2Config(PreTrainedConfig): attention_dropout: float | None = 0.0 mlp_bias: bool = False head_dim: int | None = None + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + attribute_map = { + "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", + } first_k_dense_replace: int = 0 kv_lora_rank: int = 512 q_lora_rank: int | None = 1536 diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 3ef8266218f7..d719b3519704 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -48,7 +48,7 @@ class DeepseekV2Experts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.n_routed_experts + self.num_experts = config.num_experts self.hidden_dim = config.hidden_size self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) @@ -82,48 +82,55 @@ def forward( return final_hidden_states -class DeepseekV2Moe(nn.Module): +class DeepseekV2TopkRouter(nn.Module): def __init__(self, config: DeepseekV2Config): super().__init__() - self.config = config - self.experts = DeepseekV2Experts(config) - self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor self.topk_method = config.topk_method self.num_group = config.n_group - self.top_k = config.num_experts_per_tok self.topk_group = config.topk_group - def route_tokens_to_experts(self, router_logits): - batch_size, seq_len, hidden_dim = router_logits.shape - router_logits = router_logits.view(-1, hidden_dim) - router_logits = router_logits.softmax(dim=-1, dtype=torch.float32) + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.softmax(dim=-1, dtype=torch.float32) if self.topk_method == "greedy": - topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) elif self.topk_method == "group_limited_greedy": - group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values + group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group) - .reshape(batch_size * seq_len, -1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) ) - tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0) - topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices + - topk_weight = topk_weight * self.routed_scaling_factor - return topk_idx, topk_weight +class DeepseekV2Moe(nn.Module): + def __init__(self, config: DeepseekV2Config): + super().__init__() + self.config = config + self.experts = DeepseekV2Experts(config) + self.gate = DeepseekV2TopkRouter(config) + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -459,7 +466,9 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) - if isinstance(module, DeepseekV2Experts): + if isinstance(module, DeepseekV2TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, DeepseekV2Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 5644c7dc2990..364dbf5c43b5 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -37,7 +37,7 @@ LlamaRotaryEmbedding, eager_attention_forward, ) -from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts, Qwen2MoeTopKRouter logger = logging.get_logger(__name__) @@ -82,9 +82,19 @@ class DeepseekV2Config(LlamaConfig): "layers.*.mlp.up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } model_type = "deepseek_v2" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", + } vocab_size: int = 32000 hidden_size: int = 4096 @@ -143,53 +153,53 @@ def apply_rotary_emb( class DeepseekV2Experts(Qwen2MoeExperts): - def __init__(self, config): - super().__init__(config) - self.num_experts = config.n_routed_experts + pass -class DeepseekV2Moe(nn.Module): +class DeepseekV2TopkRouter(Qwen2MoeTopKRouter): def __init__(self, config: DeepseekV2Config): - super().__init__() - self.config = config - self.experts = DeepseekV2Experts(config) - self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) + super().__init__(config) self.routed_scaling_factor = config.routed_scaling_factor self.topk_method = config.topk_method self.num_group = config.n_group - self.top_k = config.num_experts_per_tok self.topk_group = config.topk_group - def route_tokens_to_experts(self, router_logits): - batch_size, seq_len, hidden_dim = router_logits.shape - router_logits = router_logits.view(-1, hidden_dim) - router_logits = router_logits.softmax(dim=-1, dtype=torch.float32) + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.softmax(dim=-1, dtype=torch.float32) if self.topk_method == "greedy": - topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) elif self.topk_method == "group_limited_greedy": - group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values + group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group) - .reshape(batch_size * seq_len, -1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) ) - tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0) - topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + scores = scores.masked_fill(~score_mask.bool(), 0.0) + topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices + - topk_weight = topk_weight * self.routed_scaling_factor - return topk_idx, topk_weight +class DeepseekV2Moe(nn.Module): + def __init__(self, config: DeepseekV2Config): + super().__init__() + self.config = config + self.experts = DeepseekV2Experts(config) + self.gate = DeepseekV2TopkRouter(config) + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -349,7 +359,9 @@ class DeepseekV2PreTrainedModel(LlamaPreTrainedModel): @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) - if isinstance(module, DeepseekV2Experts): + if isinstance(module, DeepseekV2TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, DeepseekV2Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py index 4178547a5ff2..4fdc494972e4 100644 --- a/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py @@ -64,8 +64,15 @@ class DeepseekV3Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 129280 diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 9ca55dad5580..68bc1b588937 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -139,20 +139,46 @@ def forward(self, x): class DeepseekV3TopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor + self.num_group = config.n_group + self.topk_group = config.topk_group + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class DeepseekV3NaiveMoe(nn.Module): +class DeepseekV3Experts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -196,51 +222,19 @@ class DeepseekV3MoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: DeepseekV3Config): super().__init__() self.config = config - self.experts = DeepseekV3NaiveMoe(config) + self.experts = DeepseekV3Experts(config) self.gate = DeepseekV3TopkRouter(config) self.shared_experts = DeepseekV3MLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -553,7 +547,7 @@ def _init_weights(self, module): if isinstance(module, DeepseekV3TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, DeepseekV3NaiveMoe): + elif isinstance(module, DeepseekV3Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 9117f48eff50..22b3c84dabfe 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -13,6 +13,7 @@ from ...processing_utils import Unpack from ...utils import logging from ...utils.generic import is_flash_attention_requested +from ..deepseek_v2.modeling_deepseek_v2 import DeepseekV2Moe, DeepseekV2TopkRouter from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaForCausalLM, @@ -88,53 +89,20 @@ def yarn_get_mscale(scale=1, mscale=1): return 0.1 * mscale * math.log(scale) + 1.0 -class DeepseekV3TopkRouter(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) - - def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits - - -class DeepseekV3NaiveMoe(MixtralExperts): +class DeepseekV3TopkRouter(DeepseekV2TopkRouter): def __init__(self, config): super().__init__(config) - self.num_experts = config.num_local_experts - self.intermediate_dim = config.moe_intermediate_size - - -class DeepseekV3MoE(nn.Module): - """ - A mixed expert module containing shared experts. - """ - - def __init__(self, config): - super().__init__() - self.config = config - self.experts = DeepseekV3NaiveMoe(config) - self.gate = DeepseekV3TopkRouter(config) - self.shared_experts = DeepseekV3MLP( - config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts - ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group + del self.topk_method self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.hidden_dim) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) .topk(2, dim=-1)[0] .sum(dim=-1) ) @@ -143,27 +111,29 @@ def route_tokens_to_experts(self, router_logits): group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) + topk_weights = scores.gather(1, topk_indices) if self.norm_topk_prob: denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 topk_weights /= denominator topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights + return router_logits, topk_weights, topk_indices - def forward(self, hidden_states): - residuals = hidden_states - orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states + +class DeepseekV3Experts(MixtralExperts): + def __init__(self, config): + super().__init__(config) + self.intermediate_dim = config.moe_intermediate_size + + +class DeepseekV3MoE(DeepseekV2Moe): + """ + A mixed expert module containing shared experts. + """ class DeepseekV3Attention(nn.Module): @@ -310,7 +280,7 @@ def _init_weights(self, module): if isinstance(module, DeepseekV3TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, DeepseekV3NaiveMoe): + elif isinstance(module, DeepseekV3Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py b/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py index 5ef27459d13e..f9d478c6eb4f 100644 --- a/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py +++ b/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py @@ -77,8 +77,17 @@ class DeepseekV32Config(PreTrainedConfig, RotaryEmbeddingConfigMixin): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } - attribute_map = {"num_local_experts": "num_experts"} + attribute_map = { + "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", + } vocab_size: int = 129280 hidden_size: int = 7168 @@ -116,7 +125,6 @@ class DeepseekV32Config(PreTrainedConfig, RotaryEmbeddingConfigMixin): index_head_dim: int = 128 index_n_heads: int = 64 mlp_bias: bool = False - num_experts: int = 256 head_dim: int = 64 first_k_dense_replace: int = 3 layer_types: list[str] | None = None diff --git a/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py b/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py index 46f6f6c74762..19311a4d09f8 100644 --- a/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py +++ b/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py @@ -492,27 +492,48 @@ def forward(self, x): class DeepseekV32TopkRouter(nn.Module): - def __init__(self, config: DeepseekV32Config): + def __init__(self, config): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class DeepseekV32NaiveMoe(nn.Module): +class DeepseekV32Experts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -556,51 +577,19 @@ class DeepseekV32MoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: DeepseekV32Config): super().__init__() self.config = config - self.experts = DeepseekV32NaiveMoe(config) + self.experts = DeepseekV32Experts(config) self.gate = DeepseekV32TopkRouter(config) self.shared_experts = DeepseekV32MLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -680,7 +669,7 @@ def _init_weights(self, module): if isinstance(module, DeepseekV32TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, DeepseekV32NaiveMoe): + elif isinstance(module, DeepseekV32Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/deepseek_v32/modular_deepseek_v32.py b/src/transformers/models/deepseek_v32/modular_deepseek_v32.py index 77be4aa9c943..fd982a8bcb64 100644 --- a/src/transformers/models/deepseek_v32/modular_deepseek_v32.py +++ b/src/transformers/models/deepseek_v32/modular_deepseek_v32.py @@ -99,7 +99,10 @@ class DeepseekV32Config(Glm4MoeLiteConfig, RotaryEmbeddingConfigMixin): "layers.*.mlp.down_proj": "rowwise", } - attribute_map = {"num_local_experts": "num_experts"} + attribute_map = { + "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", + } vocab_size: int = 129280 hidden_size: int = 7168 @@ -137,7 +140,6 @@ class DeepseekV32Config(Glm4MoeLiteConfig, RotaryEmbeddingConfigMixin): index_head_dim: int = 128 index_n_heads: int = 64 mlp_bias: bool = False - num_experts: int = 256 head_dim: int = 64 first_k_dense_replace: int = 3 pretraining_tp = AttributeError() diff --git a/src/transformers/models/dots1/configuration_dots1.py b/src/transformers/models/dots1/configuration_dots1.py index 4d568bf4a565..5de79eea7bb5 100644 --- a/src/transformers/models/dots1/configuration_dots1.py +++ b/src/transformers/models/dots1/configuration_dots1.py @@ -70,8 +70,15 @@ class Dots1Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 152064 diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 95b21258ffd5..a9dfbcbfb03c 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -295,20 +295,46 @@ def forward(self, x): class Dots1TopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor + self.num_group = config.n_group + self.topk_group = config.topk_group + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class Dots1NaiveMoe(nn.Module): +class Dots1Experts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -352,51 +378,19 @@ class Dots1MoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: Dots1Config): super().__init__() self.config = config - self.experts = Dots1NaiveMoe(config) + self.experts = Dots1Experts(config) self.gate = Dots1TopkRouter(config) self.shared_experts = Dots1MLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -476,7 +470,7 @@ def _init_weights(self, module): if isinstance(module, Dots1TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, Dots1NaiveMoe): + elif isinstance(module, Dots1Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/dots1/modular_dots1.py b/src/transformers/models/dots1/modular_dots1.py index 06402d63e28c..20907d2ddd0b 100644 --- a/src/transformers/models/dots1/modular_dots1.py +++ b/src/transformers/models/dots1/modular_dots1.py @@ -84,8 +84,15 @@ class Dots1Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 152064 diff --git a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py index 0c0c0edbb760..47c7a53dff84 100644 --- a/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py @@ -82,6 +82,12 @@ class Ernie4_5_MoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 103424 pad_token_id: int | None = 0 diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 663cb83ae32d..833f4484bcfd 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -357,6 +357,7 @@ def __init__(self, config): self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)) self.moe_statics = Ernie4_5_MoeStatics(config) self.top_k = config.moe_k + self.num_experts = config.moe_num_experts self.norm_min = config.moe_norm_min def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -375,7 +376,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) routing_weights = routing_weights.to(hidden_states.dtype) - return router_logits, selected_experts, routing_weights + return router_logits, routing_weights, selected_experts class Ernie4_5_MoeSparseMoeBlock(nn.Module): @@ -398,7 +399,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) - _, top_k_index, top_k_weights = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) if self.shared_experts is not None: diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index 0fdd889e7171..45f8a3a4e75d 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -111,6 +111,7 @@ def __init__(self, config): self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)) self.moe_statics = Ernie4_5_MoeStatics(config) self.top_k = config.moe_k + self.num_experts = config.moe_num_experts self.norm_min = config.moe_norm_min def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -129,7 +130,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) routing_weights = routing_weights.to(hidden_states.dtype) - return router_logits, selected_experts, routing_weights + return router_logits, routing_weights, selected_experts class Ernie4_5_MoeSparseMoeBlock(nn.Module): @@ -152,7 +153,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) - _, top_k_index, top_k_weights = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) if self.shared_experts is not None: diff --git a/src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py index e4eea836f107..e93297a4eb5a 100644 --- a/src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py @@ -99,6 +99,12 @@ class Ernie4_5_VLMoeTextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 103424 pad_token_id: int | None = None diff --git a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py index e543a0bf198b..8dc1c5f25a84 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py @@ -347,6 +347,7 @@ def __init__(self, config): self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)) self.moe_statics = Ernie4_5_VLMoeMoeStatics(config) self.top_k = config.moe_k + self.num_experts = config.moe_num_experts self.norm_min = config.moe_norm_min def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -365,7 +366,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) routing_weights = routing_weights.to(hidden_states.dtype) - return router_logits, selected_experts, routing_weights + return router_logits, routing_weights, selected_experts @use_experts_implementation @@ -423,7 +424,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = hidden_states.view(-1, self.hidden_dim) - router_logits, top_k_index, top_k_weights = self.gate(hidden_states) + router_logits, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) # moe results are changed to a flattened shape to ease the modality isolated assigning of results diff --git a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py index c970757313a3..0ddb8564170f 100644 --- a/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +++ b/src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py @@ -397,7 +397,7 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = hidden_states.view(-1, self.hidden_dim) - router_logits, top_k_index, top_k_weights = self.gate(hidden_states) + router_logits, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) # moe results are changed to a flattened shape to ease the modality isolated assigning of results diff --git a/src/transformers/models/exaone_moe/configuration_exaone_moe.py b/src/transformers/models/exaone_moe/configuration_exaone_moe.py index 81ded9366cdb..16ab2e2f4028 100644 --- a/src/transformers/models/exaone_moe/configuration_exaone_moe.py +++ b/src/transformers/models/exaone_moe/configuration_exaone_moe.py @@ -102,6 +102,16 @@ class ExaoneMoeConfig(PreTrainedConfig): sliding_window: int = 4096 sliding_window_pattern: str | int | None = 4 layer_types: list[str] | None = None + + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + attribute_map = { + "num_local_experts": "num_experts", + } mlp_layer_types: list[str] | None = None first_k_dense_replace: int = 1 moe_intermediate_size: int = 1024 diff --git a/src/transformers/models/exaone_moe/modeling_exaone_moe.py b/src/transformers/models/exaone_moe/modeling_exaone_moe.py index a7f80fc979c4..a125a707168f 100644 --- a/src/transformers/models/exaone_moe/modeling_exaone_moe.py +++ b/src/transformers/models/exaone_moe/modeling_exaone_moe.py @@ -23,8 +23,8 @@ from typing import Optional import torch -import torch.nn as nn import torch.nn.functional as F +from torch import nn from ... import initialization as init from ...activations import ACT2FN @@ -227,14 +227,42 @@ def forward(self, x): class ExaoneMoeTopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config - self.weight = nn.Parameter(torch.empty((config.num_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(config.num_experts)) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor + self.num_group = config.n_group + self.topk_group = config.topk_group + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation @@ -243,7 +271,7 @@ class ExaoneMoeExperts(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_experts + self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) @@ -290,43 +318,11 @@ def __init__(self, config): self.shared_experts = ExaoneMoeMLP( config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts ) - self.n_routed_experts = config.num_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) diff --git a/src/transformers/models/exaone_moe/modular_exaone_moe.py b/src/transformers/models/exaone_moe/modular_exaone_moe.py index 75ec2b0bfd27..eee97bf7fdda 100644 --- a/src/transformers/models/exaone_moe/modular_exaone_moe.py +++ b/src/transformers/models/exaone_moe/modular_exaone_moe.py @@ -15,7 +15,6 @@ """LG AI Research EXAONE Lab""" import torch -import torch.nn as nn from huggingface_hub.dataclasses import strict from ... import initialization as init @@ -25,8 +24,8 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3Experts, DeepseekV3MoE, - DeepseekV3NaiveMoe, DeepseekV3TopkRouter, ) from ..exaone4.configuration_exaone4 import Exaone4Config @@ -80,6 +79,16 @@ class ExaoneMoeConfig(Exaone4Config): >>> configuration = model.config ```""" + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } + attribute_map = { + "num_local_experts": "num_experts", + } + vocab_size: int = 102400 hidden_size: int = 4096 intermediate_size: int = 16384 @@ -129,17 +138,11 @@ class ExaoneMoeMLP(Qwen2MoeMLP): class ExaoneMoeTopkRouter(DeepseekV3TopkRouter): - def __init__(self, config): - nn.Module.__init__() - self.config = config - self.weight = nn.Parameter(torch.empty((config.num_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(config.num_experts)) + pass -class ExaoneMoeExperts(DeepseekV3NaiveMoe): - def __init__(self, config): - super().__init__(config) - self.num_experts = config.num_experts +class ExaoneMoeExperts(DeepseekV3Experts): + pass class ExaoneMoeSparseMoEBlock(DeepseekV3MoE): @@ -149,7 +152,6 @@ def __init__(self, config): self.shared_experts = ExaoneMoeMLP( config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts ) - self.n_routed_experts = config.num_experts class ExaoneMoeDecoderLayer(OlmoeDecoderLayer): diff --git a/src/transformers/models/flex_olmo/configuration_flex_olmo.py b/src/transformers/models/flex_olmo/configuration_flex_olmo.py index 7b08a79b801b..b4fff038c4a4 100644 --- a/src/transformers/models/flex_olmo/configuration_flex_olmo.py +++ b/src/transformers/models/flex_olmo/configuration_flex_olmo.py @@ -63,6 +63,12 @@ class FlexOlmoConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 100352 hidden_size: int = 4096 diff --git a/src/transformers/models/flex_olmo/modular_flex_olmo.py b/src/transformers/models/flex_olmo/modular_flex_olmo.py index 01f32227f31f..5bc7e0ad828f 100644 --- a/src/transformers/models/flex_olmo/modular_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modular_flex_olmo.py @@ -73,6 +73,12 @@ class FlexOlmoConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 100352 hidden_size: int = 4096 diff --git a/src/transformers/models/glm4_moe/configuration_glm4_moe.py b/src/transformers/models/glm4_moe/configuration_glm4_moe.py index a18123e90b33..d6d1b329e443 100644 --- a/src/transformers/models/glm4_moe/configuration_glm4_moe.py +++ b/src/transformers/models/glm4_moe/configuration_glm4_moe.py @@ -17,6 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from huggingface_hub.dataclasses import strict from ...configuration_utils import PreTrainedConfig @@ -73,8 +74,15 @@ class Glm4MoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 151552 diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 53e1abf53c02..0ffc67286ba0 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -17,7 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from collections.abc import Callable from typing import Optional @@ -285,23 +284,44 @@ def forward(self, x): class Glm4MoeTopkRouter(nn.Module): - def __init__(self, config: Glm4MoeConfig): + def __init__(self, config): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_kernel_forward_from_hub("RMSNorm") @@ -326,7 +346,7 @@ def extra_repr(self): @use_experts_implementation -class Glm4MoeNaiveMoe(nn.Module): +class Glm4MoeExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -370,51 +390,19 @@ class Glm4MoeMoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: Glm4MoeConfig): super().__init__() self.config = config - self.experts = Glm4MoeNaiveMoe(config) + self.experts = Glm4MoeExperts(config) self.gate = Glm4MoeTopkRouter(config) self.shared_experts = Glm4MoeMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -494,7 +482,7 @@ def _init_weights(self, module): if isinstance(module, Glm4MoeTopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, Glm4MoeNaiveMoe): + elif isinstance(module, Glm4MoeExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/glm4_moe/modular_glm4_moe.py b/src/transformers/models/glm4_moe/modular_glm4_moe.py index 868018d744b5..12353f7541be 100644 --- a/src/transformers/models/glm4_moe/modular_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modular_glm4_moe.py @@ -13,7 +13,6 @@ # limitations under the License. """PyTorch GLM-4-MOE model.""" -import torch from huggingface_hub.dataclasses import strict from torch import nn @@ -86,8 +85,15 @@ class Glm4MoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 151552 @@ -161,18 +167,7 @@ class Glm4MoeMLP(DeepseekV3MLP): class Glm4MoeTopkRouter(DeepseekV3TopkRouter): - def __init__(self, config: Glm4MoeConfig): - nn.Module.__init__(self) - self.config = config - self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts - self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + pass class Glm4MoeRMSNorm(DeepseekV3RMSNorm): diff --git a/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py index a9518ed9c5d3..4c68af6b541d 100644 --- a/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py @@ -68,8 +68,15 @@ class Glm4MoeLiteConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", "head_dim": "qk_rope_head_dim", } diff --git a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py index e0d7b8fddb14..da043449efb1 100644 --- a/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py @@ -362,23 +362,44 @@ def forward(self, x): class Glm4MoeLiteTopkRouter(nn.Module): - def __init__(self, config: Glm4MoeLiteConfig): + def __init__(self, config): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_kernel_forward_from_hub("RMSNorm") @@ -403,7 +424,7 @@ def extra_repr(self): @use_experts_implementation -class Glm4MoeLiteNaiveMoe(nn.Module): +class Glm4MoeLiteExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -447,51 +468,19 @@ class Glm4MoeLiteMoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: Glm4MoeLiteConfig): super().__init__() self.config = config - self.experts = Glm4MoeLiteNaiveMoe(config) + self.experts = Glm4MoeLiteExperts(config) self.gate = Glm4MoeLiteTopkRouter(config) self.shared_experts = Glm4MoeLiteMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -570,7 +559,7 @@ def _init_weights(self, module): if isinstance(module, Glm4MoeLiteTopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, Glm4MoeLiteNaiveMoe): + elif isinstance(module, Glm4MoeLiteExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py b/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py index 1f65f44a525a..097173ab229a 100644 --- a/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +++ b/src/transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py @@ -22,11 +22,11 @@ from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3Attention from ..glm4_moe.modeling_glm4_moe import ( Glm4MoeDecoderLayer, + Glm4MoeExperts, Glm4MoeForCausalLM, Glm4MoeMLP, Glm4MoeModel, Glm4MoeMoE, - Glm4MoeNaiveMoe, Glm4MoePreTrainedModel, Glm4MoeRMSNorm, Glm4MoeRotaryEmbedding, @@ -76,8 +76,15 @@ class Glm4MoeLiteConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", "head_dim": "qk_rope_head_dim", } @@ -144,7 +151,7 @@ class Glm4MoeLiteRMSNorm(Glm4MoeRMSNorm): pass -class Glm4MoeLiteNaiveMoe(Glm4MoeNaiveMoe): +class Glm4MoeLiteExperts(Glm4MoeExperts): pass diff --git a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py index 0e4d6a9cb191..1e9aa238cc7e 100644 --- a/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/configuration_glm4v_moe.py @@ -66,8 +66,15 @@ class Glm4vMoeTextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 151424 diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 3c7d758e538c..50dcd6b2ebbe 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -54,7 +54,7 @@ maybe_autocast, merge_with_config_defaults, ) -from ...utils.output_capturing import capture_outputs +from ...utils.output_capturing import OutputRecorder, capture_outputs from ...vision_utils import get_vision_cu_seqlens, get_vision_position_ids from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig @@ -212,25 +212,46 @@ def forward( class Glm4vMoeTextTopkRouter(nn.Module): def __init__(self, config: Glm4vMoeTextConfig): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class Glm4vMoeTextNaiveMoe(nn.Module): +class Glm4vMoeTextExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -277,48 +298,16 @@ class Glm4vMoeTextMoE(nn.Module): def __init__(self, config: Glm4vMoeTextConfig): super().__init__() self.config = config - self.experts = Glm4vMoeTextNaiveMoe(config) + self.experts = Glm4vMoeTextExperts(config) self.gate = Glm4vMoeTextTopkRouter(config) self.shared_experts = Glm4vMoeTextMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -433,7 +422,7 @@ def _init_weights(self, module): if isinstance(module, Glm4vMoeTextTopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, Glm4vMoeTextNaiveMoe): + elif isinstance(module, Glm4vMoeTextExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) if isinstance(module, Glm4vMoeVisionRotaryEmbedding): @@ -932,7 +921,7 @@ class Glm4vMoeTextModel(Glm4vMoePreTrainedModel): _can_record_outputs = { "hidden_states": Glm4vMoeTextDecoderLayer, "attentions": Glm4vMoeTextAttention, - "router_logits": Glm4vMoeTextTopkRouter, + "router_logits": OutputRecorder(Glm4vMoeTextTopkRouter, index=0), } def __init__(self, config: Glm4vMoeTextConfig): diff --git a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py index b1c1293a00b9..3371cff667a2 100644 --- a/src/transformers/models/glm4v_moe/modular_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modular_glm4v_moe.py @@ -26,7 +26,8 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging from ...utils.generic import can_return_tuple -from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3NaiveMoe +from ...utils.output_capturing import OutputRecorder +from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3Experts from ..glm4.modeling_glm4 import Glm4Attention from ..glm4_moe.configuration_glm4_moe import Glm4MoeConfig from ..glm4_moe.modeling_glm4_moe import ( @@ -195,7 +196,7 @@ def __init__(self, config: Glm4vMoeTextConfig): super().__init__(config) -class Glm4vMoeTextNaiveMoe(DeepseekV3NaiveMoe): +class Glm4vMoeTextExperts(DeepseekV3Experts): pass @@ -203,7 +204,7 @@ class Glm4vMoeTextMoE(Glm4MoeMoE): def __init__(self, config: Glm4vMoeTextConfig): super().__init__(config) self.config = config - self.experts = Glm4vMoeTextNaiveMoe(config) + self.experts = Glm4vMoeTextExperts(config) self.gate = Glm4vMoeTextTopkRouter(config) self.shared_experts = Glm4vMoeTextMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts @@ -252,7 +253,7 @@ class Glm4vMoeTextModel(Glm4vTextModel): _can_record_outputs = { "hidden_states": Glm4vMoeTextDecoderLayer, "attentions": Glm4vMoeTextAttention, - "router_logits": Glm4vMoeTextTopkRouter, + "router_logits": OutputRecorder(Glm4vMoeTextTopkRouter, index=0), } def forward( diff --git a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py index 5ecf328bd170..e2693981c192 100644 --- a/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py @@ -82,9 +82,16 @@ class GlmMoeDsaConfig(PreTrainedConfig, RotaryEmbeddingConfigMixin): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 154880 @@ -123,7 +130,6 @@ class GlmMoeDsaConfig(PreTrainedConfig, RotaryEmbeddingConfigMixin): index_head_dim: int = 128 index_n_heads: int = 32 mlp_bias: bool = False - num_experts: int = 256 head_dim: int = 64 first_k_dense_replace: int = 3 layer_types: list[str] | None = None diff --git a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py index 34d7bf139c36..019fd7261311 100644 --- a/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modeling_glm_moe_dsa.py @@ -508,27 +508,48 @@ def forward(self, x): class GlmMoeDsaTopkRouter(nn.Module): - def __init__(self, config: GlmMoeDsaConfig): + def __init__(self, config): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class GlmMoeDsaNaiveMoe(nn.Module): +class GlmMoeDsaExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -572,51 +593,19 @@ class GlmMoeDsaMoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: GlmMoeDsaConfig): super().__init__() self.config = config - self.experts = GlmMoeDsaNaiveMoe(config) + self.experts = GlmMoeDsaExperts(config) self.gate = GlmMoeDsaTopkRouter(config) self.shared_experts = GlmMoeDsaMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -697,7 +686,7 @@ def _init_weights(self, module): if isinstance(module, GlmMoeDsaTopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, GlmMoeDsaNaiveMoe): + elif isinstance(module, GlmMoeDsaExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py index a4e91beb4633..72b2095376eb 100644 --- a/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py +++ b/src/transformers/models/glm_moe_dsa/modular_glm_moe_dsa.py @@ -78,10 +78,6 @@ class GlmMoeDsaConfig(DeepseekV32Config): >>> configuration = model.config ```""" - attribute_map = { - "num_local_experts": "n_routed_experts", - } - vocab_size: int = 154880 hidden_size: int = 6144 intermediate_size: int = 12288 diff --git a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py index 27c6fb149aa4..f72c6155c5ea 100644 --- a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py @@ -33,6 +33,12 @@ class HunYuanMoEV1Config(PreTrainedConfig): model_type = "hunyuan_v1_moe" keys_to_ignore_at_inference = ["past_key_values"] + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_experts_per_tok": "moe_topk", "num_local_experts": "num_experts", diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 19779da0528c..310e32e3528a 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -231,16 +231,19 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int | None = None): super().__init__() self.config = config self.layer_idx = layer_idx - num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] - self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32) + self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.wg = nn.Linear(config.hidden_size, self.num_experts, bias=False, dtype=torch.float32) def forward(self, hidden_states): - bsz, seq_len, hidden_size = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_size) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() - logits = self.wg(hidden_states) - return logits + router_logits = self.wg(hidden_states) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + return router_logits, routing_weights.to(router_logits.dtype), selected_experts @use_experts_implementation @@ -294,18 +297,11 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int | None = None): self.experts = HunYuanMoEV1Experts(config) self.shared_mlp = HunYuanMoEV1MLP(config) - def route_tokens_to_experts(self, hidden_states): - routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - return selected_experts, routing_weights.to(hidden_states.dtype) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_mlp = self.shared_mlp(hidden_states) - router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_dim) - selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + _, routing_weights, selected_experts = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights).reshape( batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index 2a60c26bc391..231e8acfe596 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -111,16 +111,19 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int | None = None): super().__init__() self.config = config self.layer_idx = layer_idx - num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] - self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32) + self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] + self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.wg = nn.Linear(config.hidden_size, self.num_experts, bias=False, dtype=torch.float32) def forward(self, hidden_states): - bsz, seq_len, hidden_size = hidden_states.shape - hidden_states = hidden_states.reshape(-1, hidden_size) + hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() - logits = self.wg(hidden_states) - return logits + router_logits = self.wg(hidden_states) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + return router_logits, routing_weights.to(router_logits.dtype), selected_experts class HunYuanMoEV1Experts(MixtralExperts): @@ -138,18 +141,11 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int | None = None): self.experts = HunYuanMoEV1Experts(config) self.shared_mlp = HunYuanMoEV1MLP(config) - def route_tokens_to_experts(self, hidden_states): - routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - return selected_experts, routing_weights.to(hidden_states.dtype) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_mlp = self.shared_mlp(hidden_states) - router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_dim) - selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + _, routing_weights, selected_experts = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights).reshape( batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/hy_v3/configuration_hy_v3.py b/src/transformers/models/hy_v3/configuration_hy_v3.py index e2eee94b118a..4399a0f4d546 100644 --- a/src/transformers/models/hy_v3/configuration_hy_v3.py +++ b/src/transformers/models/hy_v3/configuration_hy_v3.py @@ -72,6 +72,12 @@ class HYV3Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 120832 hidden_size: int = 4096 diff --git a/src/transformers/models/hy_v3/modular_hy_v3.py b/src/transformers/models/hy_v3/modular_hy_v3.py index fa0931435197..c24758bb5100 100644 --- a/src/transformers/models/hy_v3/modular_hy_v3.py +++ b/src/transformers/models/hy_v3/modular_hy_v3.py @@ -97,6 +97,12 @@ class HYV3Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 120832 hidden_size: int = 4096 diff --git a/src/transformers/models/laguna/configuration_laguna.py b/src/transformers/models/laguna/configuration_laguna.py index 6dae15bd1970..94650b384541 100644 --- a/src/transformers/models/laguna/configuration_laguna.py +++ b/src/transformers/models/laguna/configuration_laguna.py @@ -30,12 +30,12 @@ @strict class LagunaConfig(PreTrainedConfig): r""" - num_attention_heads_per_layer (`list[int]`, *optional*): - Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``. gating (`bool` or `str`, *optional*, defaults to `True`): Softplus output-gate granularity. ``True`` or ``"per-head"`` applies one gate per head, broadcast across ``head_dim``; ``"per-element"`` applies one gate per ``(head, head_dim)`` channel. + num_attention_heads_per_layer (`list[int]`, *optional*): + Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``. mlp_layer_types (`list[str]`, *optional*): Per-layer MLP type — ``"dense"`` or ``"sparse"``. Length must equal ``num_hidden_layers``. Defaults to first layer dense, rest sparse. @@ -83,6 +83,12 @@ class LagunaConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 100352 hidden_size: int = 2048 diff --git a/src/transformers/models/laguna/modular_laguna.py b/src/transformers/models/laguna/modular_laguna.py index 659cefb8a9ce..145e44266b86 100644 --- a/src/transformers/models/laguna/modular_laguna.py +++ b/src/transformers/models/laguna/modular_laguna.py @@ -49,12 +49,12 @@ @strict class LagunaConfig(Qwen2MoeConfig): r""" - num_attention_heads_per_layer (`list[int]`, *optional*): - Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``. gating (`bool` or `str`, *optional*, defaults to `True`): Softplus output-gate granularity. ``True`` or ``"per-head"`` applies one gate per head, broadcast across ``head_dim``; ``"per-element"`` applies one gate per ``(head, head_dim)`` channel. + num_attention_heads_per_layer (`list[int]`, *optional*): + Per-layer override for ``num_attention_heads``. Length must equal ``num_hidden_layers``. mlp_layer_types (`list[str]`, *optional*): Per-layer MLP type — ``"dense"`` or ``"sparse"``. Length must equal ``num_hidden_layers``. Defaults to first layer dense, rest sparse. diff --git a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py index b5ae5c76bdde..7b940d740aaa 100644 --- a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py @@ -48,6 +48,12 @@ class Lfm2MoeConfig(PreTrainedConfig): model_type = "lfm2_moe" keys_to_ignore_at_inference = ["past_key_values"] + base_model_ep_plan = { + "layers.*.feed_forward.gate": "ep_router", + "layers.*.feed_forward.experts.gate_up_proj": "grouped_gemm", + "layers.*.feed_forward.experts.down_proj": "grouped_gemm", + "layers.*.feed_forward.experts": "moe_tp_experts", + } default_theta = 1000000.0 vocab_size: int = 65536 diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 159e82f95147..a2d221050a68 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -192,23 +192,22 @@ def forward( return final_hidden_states -class Lfm2MoeSparseMoeBlock(nn.Module): +class Lfm2MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok - self.routed_scaling_factor = config.routed_scaling_factor + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor self.use_expert_bias = config.use_expert_bias - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Lfm2MoeExperts(config) - if self.use_expert_bias: - self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) - - def route_tokens_to_experts(self, router_logits): + def forward(self, hidden_states, expert_bias=None): + router_logits = F.linear(hidden_states, self.weight) routing_weights = router_logits.sigmoid() if self.use_expert_bias: - scores_for_routing = routing_weights + self.expert_bias + scores_for_routing = routing_weights + expert_bias _, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1) routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits) else: @@ -217,13 +216,23 @@ def route_tokens_to_experts(self, router_logits): if self.norm_topk_prob: routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6) routing_weights = routing_weights * self.routed_scaling_factor - return selected_experts, routing_weights + return router_logits, routing_weights, selected_experts + + +class Lfm2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.experts = Lfm2MoeExperts(config) + self.gate = Lfm2MoeTopKRouter(config) + self.use_expert_bias = config.use_expert_bias + if self.use_expert_bias: + self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + expert_bias = self.expert_bias if self.use_expert_bias else None + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped, expert_bias) final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -563,6 +572,8 @@ def _init_weights(self, module): if isinstance(module, Lfm2MoeExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Lfm2MoeTopKRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, Lfm2MoeSparseMoeBlock): if module.use_expert_bias: init.zeros_(module.expert_bias) diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py index a1b2799d2bae..76ce9e993166 100644 --- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -33,7 +33,8 @@ ) from ..llama.modeling_llama import LlamaForCausalLM, LlamaPreTrainedModel, LlamaRMSNorm from ..mixtral.modeling_mixtral import MixtralModel -from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts, Qwen2MoeTopKRouter +from ..qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from .configuration_lfm2_moe import Lfm2MoeConfig @@ -74,23 +75,17 @@ def __init__(self, config): self.act_fn = F.silu -class Lfm2MoeSparseMoeBlock(nn.Module): +class Lfm2MoeTopKRouter(Qwen2MoeTopKRouter): def __init__(self, config): - super().__init__() - self.top_k = config.num_experts_per_tok + super().__init__(config) self.routed_scaling_factor = config.routed_scaling_factor - self.norm_topk_prob = config.norm_topk_prob self.use_expert_bias = config.use_expert_bias - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Lfm2MoeExperts(config) - if self.use_expert_bias: - self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) - - def route_tokens_to_experts(self, router_logits): + def forward(self, hidden_states, expert_bias=None): + router_logits = F.linear(hidden_states, self.weight) routing_weights = router_logits.sigmoid() if self.use_expert_bias: - scores_for_routing = routing_weights + self.expert_bias + scores_for_routing = routing_weights + expert_bias _, selected_experts = torch.topk(scores_for_routing, k=self.top_k, dim=-1) routing_weights = torch.gather(routing_weights, dim=1, index=selected_experts).type_as(router_logits) else: @@ -99,13 +94,21 @@ def route_tokens_to_experts(self, router_logits): if self.norm_topk_prob: routing_weights = routing_weights / (routing_weights.sum(dim=-1, keepdim=True) + 1e-6) routing_weights = routing_weights * self.routed_scaling_factor - return selected_experts, routing_weights + return router_logits, routing_weights, selected_experts + + +class Lfm2MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def __init__(self, config): + super().__init__(config) + self.use_expert_bias = config.use_expert_bias + if self.use_expert_bias: + self.register_buffer("expert_bias", torch.zeros(config.num_experts, dtype=torch.float32)) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(router_logits) + expert_bias = self.expert_bias if self.use_expert_bias else None + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped, expert_bias) final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -137,6 +140,8 @@ def _init_weights(self, module): if isinstance(module, Lfm2MoeExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Lfm2MoeTopKRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) elif isinstance(module, Lfm2MoeSparseMoeBlock): if module.use_expert_bias: init.zeros_(module.expert_bias) diff --git a/src/transformers/models/longcat_flash/configuration_longcat_flash.py b/src/transformers/models/longcat_flash/configuration_longcat_flash.py index 39e5a03338d8..a7330d73cc7b 100644 --- a/src/transformers/models/longcat_flash/configuration_longcat_flash.py +++ b/src/transformers/models/longcat_flash/configuration_longcat_flash.py @@ -52,6 +52,12 @@ class LongcatFlashConfig(PreTrainedConfig): model_type = "longcat_flash" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_experts": "n_routed_experts", + "num_local_experts": "n_routed_experts", + "num_experts_per_tok": "moe_topk", + "intermediate_size": "ffn_hidden_size", + } default_theta = 10000000.0 base_model_tp_plan = { "layers.*.self_attn.*.q_b_proj": "colwise", diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index 4a287fde6e02..ffd6252562cb 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -131,11 +131,11 @@ def forward(self, x, position_ids): class LongcatFlashMLP(nn.Module): - def __init__(self, config, hidden_size=None, intermediate_size=None): + def __init__(self, config, intermediate_size=None): super().__init__() self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) @@ -149,17 +149,17 @@ def forward(self, x): class LongcatFlashTopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size self.n_routed_experts = config.n_routed_experts + (config.zero_expert_num or 0) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) - - self.top_k = config.moe_topk self.routed_scaling_factor = config.routed_scaling_factor self.router_bias = getattr(config, "router_bias", False) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32)) scores = router_logits.softmax(dim=-1) topk_indices = self.get_topk_indices(scores) diff --git a/src/transformers/models/longcat_flash/modular_longcat_flash.py b/src/transformers/models/longcat_flash/modular_longcat_flash.py index 205802f45690..51f25a1d630d 100644 --- a/src/transformers/models/longcat_flash/modular_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modular_longcat_flash.py @@ -36,10 +36,10 @@ DeepseekV3Model, DeepseekV3RMSNorm, DeepseekV3RotaryEmbedding, - DeepseekV3TopkRouter, apply_rotary_pos_emb_interleave, eager_attention_forward, ) +from ..mixtral.modeling_mixtral import MixtralTopKRouter from .configuration_longcat_flash import LongcatFlashConfig @@ -54,28 +54,18 @@ class LongcatFlashRotaryEmbedding(DeepseekV3RotaryEmbedding): pass -# TODO remap config key ffn_hidden_size -> intermediate_size class LongcatFlashMLP(DeepseekV3MLP): - def __init__(self, config, hidden_size=None, intermediate_size=None): - super().__init__(config) - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size - self.intermediate_size = config.ffn_hidden_size if intermediate_size is None else intermediate_size + pass -# TODO remap config key moe_topk -> num_experts_per_tok -class LongcatFlashTopkRouter(DeepseekV3TopkRouter): +class LongcatFlashTopkRouter(MixtralTopKRouter): def __init__(self, config): super().__init__(config) - del self.n_group - del self.topk_group - del self.weight - del self.norm_topk_prob - - self.top_k = config.moe_topk + del self.weight # longcat routes through `classifier` (an nn.Linear) instead of a weight Parameter self.n_routed_experts = config.n_routed_experts + (config.zero_expert_num or 0) self.routed_scaling_factor = config.routed_scaling_factor - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) self.router_bias = getattr(config, "router_bias", False) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias) @torch.no_grad() @@ -85,7 +75,7 @@ def get_topk_indices(self, scores): return topk_indices def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.classifier.weight.type(torch.float32)) scores = router_logits.softmax(dim=-1) topk_indices = self.get_topk_indices(scores) diff --git a/src/transformers/models/minimax/configuration_minimax.py b/src/transformers/models/minimax/configuration_minimax.py index 9a3a0023725e..71a3e2b4b79c 100644 --- a/src/transformers/models/minimax/configuration_minimax.py +++ b/src/transformers/models/minimax/configuration_minimax.py @@ -75,6 +75,12 @@ class MiniMaxConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = {"num_experts": "num_local_experts"} vocab_size: int = 32000 diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 0bd400458129..13896dd21cc9 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -101,6 +101,12 @@ class MiniMaxConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = {"num_experts": "num_local_experts"} vocab_size: int = 32000 diff --git a/src/transformers/models/minimax_m2/configuration_minimax_m2.py b/src/transformers/models/minimax_m2/configuration_minimax_m2.py index 75acb8b755d7..d6647fb5feba 100644 --- a/src/transformers/models/minimax_m2/configuration_minimax_m2.py +++ b/src/transformers/models/minimax_m2/configuration_minimax_m2.py @@ -61,6 +61,12 @@ class MiniMaxM2Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_experts": "num_local_experts", } diff --git a/src/transformers/models/minimax_m2/modular_minimax_m2.py b/src/transformers/models/minimax_m2/modular_minimax_m2.py index e71774ef59fd..c66f9fedc580 100644 --- a/src/transformers/models/minimax_m2/modular_minimax_m2.py +++ b/src/transformers/models/minimax_m2/modular_minimax_m2.py @@ -80,6 +80,12 @@ class MiniMaxM2Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_experts": "num_local_experts", } diff --git a/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py index 86c0009389be..b785c6bb8a31 100644 --- a/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py +++ b/src/transformers/models/minimax_m3_vl/configuration_minimax_m3_vl.py @@ -69,6 +69,12 @@ class MiniMaxM3VLTextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_experts": "num_local_experts", } @@ -98,14 +104,6 @@ class MiniMaxM3VLTextConfig(PreTrainedConfig): router_jitter_noise: float = 0.0 rope_parameters: RopeParameters | dict | None = None base_config_key = "text_config" - base_model_ep_plan = { - "layers.*.mlp.gate": "ep_router", - "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", - "layers.*.mlp.experts.down_proj": "grouped_gemm", - "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", - "layers.*.mlp.experts": "moe_tp_experts", - } dense_intermediate_size: int = 12288 shared_intermediate_size: int = 3072 routed_scaling_factor: float = 2.0 diff --git a/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py index 87d1589a501b..542c1d20aaa2 100644 --- a/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py +++ b/src/transformers/models/minimax_m3_vl/modular_minimax_m3_vl.py @@ -93,14 +93,6 @@ class MiniMaxM3VLTextConfig(MiniMaxM2Config): model_type = "minimax_m3_vl_text" base_config_key = "text_config" - base_model_ep_plan = { - "layers.*.mlp.gate": "ep_router", - "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", - "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", - "layers.*.mlp.experts.down_proj": "grouped_gemm", - "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", - "layers.*.mlp.experts": "moe_tp_experts", - } hidden_size: int = 6144 intermediate_size: int = 3072 diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index 0e16e0a14f45..61369a5d8915 100644 --- a/src/transformers/models/mistral4/configuration_mistral4.py +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -62,8 +62,15 @@ class Mistral4Config(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 131072 diff --git a/src/transformers/models/mistral4/modeling_mistral4.py b/src/transformers/models/mistral4/modeling_mistral4.py index 23bfe1091b69..bddfe6c8a232 100644 --- a/src/transformers/models/mistral4/modeling_mistral4.py +++ b/src/transformers/models/mistral4/modeling_mistral4.py @@ -151,19 +151,42 @@ def forward(self, x): class Mistral4TopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor + self.num_group = config.n_group + self.topk_group = config.topk_group def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight) - return router_logits + scores = router_logits.softmax(-1) + group_scores = ( + scores.view(-1, self.num_group, self.num_experts // self.num_group).topk(2, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class Mistral4NaiveMoe(nn.Module): +class Mistral4Experts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -207,48 +230,19 @@ class Mistral4MoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: Mistral4Config): super().__init__() self.config = config - self.experts = Mistral4NaiveMoe(config) + self.experts = Mistral4Experts(config) self.gate = Mistral4TopkRouter(config) self.shared_experts = Mistral4MLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - router_logits = router_logits.softmax(-1) - group_scores = ( - router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -559,7 +553,7 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Mistral4TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - elif isinstance(module, Mistral4NaiveMoe): + elif isinstance(module, Mistral4Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/mistral4/modular_mistral4.py b/src/transformers/models/mistral4/modular_mistral4.py index acd9f1f60191..21fb0ddfe449 100644 --- a/src/transformers/models/mistral4/modular_mistral4.py +++ b/src/transformers/models/mistral4/modular_mistral4.py @@ -25,11 +25,12 @@ from ...processing_utils import Unpack from ...utils import logging from ...utils.generic import is_flash_attention_requested +from ..deepseek_v2.modeling_deepseek_v2 import DeepseekV2TopkRouter from ..deepseek_v3.modeling_deepseek_v3 import ( DeepseekV3Attention, DeepseekV3DecoderLayer, + DeepseekV3Experts, DeepseekV3MoE, - DeepseekV3NaiveMoe, apply_rotary_pos_emb_interleave, ) from ..llama.modeling_llama import ( @@ -60,46 +61,43 @@ class Mistral4MLP(Qwen2MoeMLP): pass -class Mistral4TopkRouter(nn.Module): +class Mistral4TopkRouter(DeepseekV2TopkRouter): def __init__(self, config): - super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + super().__init__(config) + del self.topk_method + self.norm_topk_prob = config.norm_topk_prob def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight) - return router_logits - - -class Mistral4NaiveMoe(DeepseekV3NaiveMoe): - pass - - -class Mistral4MoE(DeepseekV3MoE): - def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - router_logits = router_logits.softmax(-1) + scores = router_logits.softmax(-1) group_scores = ( - router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1) + scores.view(-1, self.num_group, self.num_experts // self.num_group).topk(2, dim=-1)[0].sum(dim=-1) ) group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] group_mask = torch.zeros_like(group_scores) group_mask.scatter_(1, group_idx, 1) score_mask = ( group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) ) - scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0) + scores_for_choice = scores.masked_fill(~score_mask.bool(), 0.0) topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) + topk_weights = scores.gather(1, topk_indices) if self.norm_topk_prob: denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 topk_weights /= denominator topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights + return router_logits, topk_weights, topk_indices + + +class Mistral4Experts(DeepseekV3Experts): + pass + + +class Mistral4MoE(DeepseekV3MoE): + pass class Mistral4Attention(DeepseekV3Attention): @@ -259,7 +257,7 @@ def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Mistral4TopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - elif isinstance(module, Mistral4NaiveMoe): + elif isinstance(module, Mistral4Experts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index 240f24411031..707e98098357 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -56,6 +56,12 @@ class MixtralConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = {"num_experts": "num_local_experts"} vocab_size: int = 32000 diff --git a/src/transformers/models/nemotron_h/configuration_nemotron_h.py b/src/transformers/models/nemotron_h/configuration_nemotron_h.py index 1c39584c45ac..8df72751376a 100644 --- a/src/transformers/models/nemotron_h/configuration_nemotron_h.py +++ b/src/transformers/models/nemotron_h/configuration_nemotron_h.py @@ -81,7 +81,7 @@ class NemotronHConfig(PreTrainedConfig): """ model_type = "nemotron_h" - attribute_map = {"layer_types": "layers_block_type"} + attribute_map = {"layer_types": "layers_block_type", "num_experts": "n_routed_experts"} keys_to_ignore_at_inference = ["past_key_values"] vocab_size: int = 131072 diff --git a/src/transformers/models/nemotron_h/modeling_nemotron_h.py b/src/transformers/models/nemotron_h/modeling_nemotron_h.py index 1327405217e0..1cd607e8919b 100644 --- a/src/transformers/models/nemotron_h/modeling_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modeling_nemotron_h.py @@ -694,12 +694,6 @@ def __init__(self, config, layer_idx: int | None = None): # Override shared_experts to use NemotronHMLP with correct intermediate size self.shared_experts = NemotronHMLP(config=config, intermediate_size=config.moe_shared_expert_intermediate_size) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok # NemotronH-specific latent projection layers if config.moe_latent_size is not None: @@ -709,36 +703,10 @@ def __init__(self, config, layer_idx: int | None = None): self.fc1_latent_proj = nn.Identity() self.fc2_latent_proj = nn.Identity() - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # NemotronH-specific: latent projection @@ -754,16 +722,42 @@ def forward(self, hidden_states): class NemotronHTopkRouter(nn.Module): def __init__(self, config): super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.routed_scaling_factor = config.routed_scaling_factor + self.num_group = config.n_group + self.topk_group = config.topk_group + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices def rotate_half(x): diff --git a/src/transformers/models/nemotron_h/modular_nemotron_h.py b/src/transformers/models/nemotron_h/modular_nemotron_h.py index c3f6f3b91acd..1c1d1483439b 100644 --- a/src/transformers/models/nemotron_h/modular_nemotron_h.py +++ b/src/transformers/models/nemotron_h/modular_nemotron_h.py @@ -209,8 +209,7 @@ def __init__(self, config, layer_idx: int | None = None): def forward(self, hidden_states): residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # NemotronH-specific: latent projection diff --git a/src/transformers/models/olmoe/configuration_olmoe.py b/src/transformers/models/olmoe/configuration_olmoe.py index 16bedbe698f8..1bdb43a14332 100644 --- a/src/transformers/models/olmoe/configuration_olmoe.py +++ b/src/transformers/models/olmoe/configuration_olmoe.py @@ -54,6 +54,12 @@ class OlmoeConfig(PreTrainedConfig): "layers.*.mlp.experts.down_proj": "rowwise", "layers.*.mlp.experts": "moe_tp_experts", } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 50304 hidden_size: int = 2048 diff --git a/src/transformers/models/phimoe/configuration_phimoe.py b/src/transformers/models/phimoe/configuration_phimoe.py index e20f94085be0..8a983774e534 100644 --- a/src/transformers/models/phimoe/configuration_phimoe.py +++ b/src/transformers/models/phimoe/configuration_phimoe.py @@ -46,6 +46,12 @@ class PhimoeConfig(PreTrainedConfig): model_type = "phimoe" keys_to_ignore_at_inference = ["past_key_values"] + base_model_ep_plan = { + "layers.*.mlp.router": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } default_theta = 1000000.0 vocab_size: int = 32064 diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 23bc944c522a..eda3f270d5e4 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -490,6 +490,7 @@ def __init__(self, config: PhimoeConfig): self.router_jitter_noise = config.router_jitter_noise self.input_jitter_noise = config.input_jitter_noise self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training and self.input_jitter_noise > 0: diff --git a/src/transformers/models/phimoe/modular_phimoe.py b/src/transformers/models/phimoe/modular_phimoe.py index 4ab1b447a03a..a4c0a081d89c 100644 --- a/src/transformers/models/phimoe/modular_phimoe.py +++ b/src/transformers/models/phimoe/modular_phimoe.py @@ -281,6 +281,7 @@ def __init__(self, config: PhimoeConfig): self.router_jitter_noise = config.router_jitter_noise self.input_jitter_noise = config.input_jitter_noise self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training and self.input_jitter_noise > 0: diff --git a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py index 7f961976b58c..318eedbe203d 100644 --- a/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/configuration_qwen2_moe.py @@ -67,6 +67,12 @@ class Qwen2MoeConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 151936 hidden_size: int = 2048 diff --git a/src/transformers/models/qwen3_5/configuration_qwen3_5.py b/src/transformers/models/qwen3_5/configuration_qwen3_5.py index 13551c377991..95c7165f9106 100644 --- a/src/transformers/models/qwen3_5/configuration_qwen3_5.py +++ b/src/transformers/models/qwen3_5/configuration_qwen3_5.py @@ -72,6 +72,12 @@ class Qwen3_5TextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 248320 hidden_size: int = 4096 diff --git a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py index 4a32828987ed..1eb16ae5db4b 100644 --- a/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py +++ b/src/transformers/models/qwen3_5_moe/configuration_qwen3_5_moe.py @@ -75,6 +75,12 @@ class Qwen3_5MoeTextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 248320 hidden_size: int = 2048 diff --git a/src/transformers/models/qwen3_next/configuration_qwen3_next.py b/src/transformers/models/qwen3_next/configuration_qwen3_next.py index bf26179ff3fd..3ad4b724cbc7 100644 --- a/src/transformers/models/qwen3_next/configuration_qwen3_next.py +++ b/src/transformers/models/qwen3_next/configuration_qwen3_next.py @@ -80,6 +80,12 @@ class Qwen3NextConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } vocab_size: int = 151936 hidden_size: int = 2048 diff --git a/src/transformers/models/solar_open/configuration_solar_open.py b/src/transformers/models/solar_open/configuration_solar_open.py index ac0016aa7791..fb0f529dc4d8 100644 --- a/src/transformers/models/solar_open/configuration_solar_open.py +++ b/src/transformers/models/solar_open/configuration_solar_open.py @@ -51,8 +51,15 @@ class SolarOpenConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = { "num_local_experts": "n_routed_experts", + "num_experts": "n_routed_experts", } vocab_size: int = 196608 diff --git a/src/transformers/models/solar_open/modeling_solar_open.py b/src/transformers/models/solar_open/modeling_solar_open.py index 0eb50021ecd6..f282fefa1a64 100644 --- a/src/transformers/models/solar_open/modeling_solar_open.py +++ b/src/transformers/models/solar_open/modeling_solar_open.py @@ -105,27 +105,48 @@ def forward(self, x): class SolarOpenTopkRouter(nn.Module): - def __init__(self, config: SolarOpenConfig): + def __init__(self, config): super().__init__() - self.config = config self.top_k = config.num_experts_per_tok - self.n_routed_experts = config.n_routed_experts + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) self.routed_scaling_factor = config.routed_scaling_factor - self.n_group = config.n_group + self.num_group = config.n_group self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32)) + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_dim) router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits + scores = router_logits.sigmoid() + scores_for_choice = scores + self.e_score_correction_bias + group_scores = ( + scores_for_choice.view(-1, self.num_group, self.num_experts // self.num_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.num_group, self.num_experts // self.num_group) + .reshape(-1, self.num_experts) + ) + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return router_logits, topk_weights, topk_indices @use_experts_implementation -class SolarOpenNaiveMoe(nn.Module): +class SolarOpenExperts(nn.Module): """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): @@ -169,51 +190,19 @@ class SolarOpenMoE(nn.Module): A mixed expert module containing shared experts. """ - def __init__(self, config): + def __init__(self, config: SolarOpenConfig): super().__init__() self.config = config - self.experts = SolarOpenNaiveMoe(config) + self.experts = SolarOpenExperts(config) self.gate = SolarOpenTopkRouter(config) self.shared_experts = SolarOpenMLP( config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + _, topk_weights, topk_indices = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) @@ -403,7 +392,7 @@ def _init_weights(self, module): if isinstance(module, SolarOpenTopkRouter): init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) init.zeros_(module.e_score_correction_bias) - elif isinstance(module, SolarOpenNaiveMoe): + elif isinstance(module, SolarOpenExperts): init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) diff --git a/src/transformers/models/youtu/configuration_youtu.py b/src/transformers/models/youtu/configuration_youtu.py index 6d9f2cef1f96..ffe0fa4fc355 100644 --- a/src/transformers/models/youtu/configuration_youtu.py +++ b/src/transformers/models/youtu/configuration_youtu.py @@ -60,6 +60,12 @@ class YoutuConfig(PreTrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + base_model_ep_plan = { + "layers.*.mlp.gate": "ep_router", + "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", + "layers.*.mlp.experts.down_proj": "grouped_gemm", + "layers.*.mlp.experts": "moe_tp_experts", + } attribute_map = {} vocab_size: int = 128256 diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 011f45e6a1a5..c205e6185f97 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -134,6 +134,27 @@ def _process_model_before_weight_loading( pre_quantized=self.pre_quantized, ) + def _process_model_after_weight_loading(self, model, **kwargs): + # dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under + # `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0 + # dtype the kernels expect (exact, since the values are powers of two). Checkpoints that + # already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched. + if self.quantization_config.scale_fmt == "ue8m0": + from ..integrations.finegrained_fp8 import _get_ue8m0_dtype + + ue8m0 = _get_ue8m0_dtype() + float32_scales = [ + name + for name, param in model.named_parameters() + if name.endswith("_scale_inv") and param.dtype == torch.float32 + ] + for name in float32_scales: + module_name, _, attr = name.rpartition(".") + module = model.get_submodule(module_name) + scale = getattr(module, attr) + setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False)) + return model + def update_tp_plan(self, config): if "Qwen3" in config.__class__.__name__: text_plan = { @@ -204,41 +225,6 @@ def get_weight_conversions(self): ] return [] - def _is_mxfp8(self) -> bool: - """MXFP8 checkpoints ship E8M0 (uint8) per-block scales; plain FP8 ships float32.""" - quant_method = getattr(self.quantization_config, "quant_method", None) - return quant_method == "mxfp8" - - def _update_weight_conversions_mxfp8(self, weight_conversions): - """ - Native MXFP8 path: prepend a `Fp8DecodeScale` op so the uint8 E8M0 - scales are decoded to float32 `2 ** (byte - 127)` *before* any merge/concat op - and add a generic fallback converter that decodes the scales of plain `FP8Linear` weights (attention / dense projections) - which have no model-specific converter. - """ - from ..core_model_loading import WeightConverter - from ..integrations.finegrained_fp8 import Fp8DecodeScale - - updated: list = [] - for conv in weight_conversions: - if isinstance(conv, WeightConverter) and any(p.endswith(".weight") for p in conv.source_patterns): - conv = WeightConverter( - source_patterns=conv.source_patterns, - target_patterns=conv._original_target_patterns, - operations=[Fp8DecodeScale(self)] + list(conv.operations), - ) - updated.append(conv) - # Generic fallback for plain ``nn.Linear`` scales with no model-specific converter. - # Listed last so the model converters above win the first-match for expert/dense scales. - updated.append( - WeightConverter( - source_patterns=["weight_scale_inv"], - target_patterns="weight_scale_inv", - operations=[Fp8DecodeScale(self)], - ) - ) - return updated - def update_weight_conversions(self, weight_conversions): """When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to every existing :class:`WeightConverter` so that per-block scales are folded into @@ -270,9 +256,6 @@ def update_weight_conversions(self, weight_conversions): weight_conversions = [scale_rename] + list(weight_conversions) if not (self.pre_quantized and self.quantization_config.dequantize): - if self.pre_quantized and self._is_mxfp8(): - # mxfp8 needs a pre-processing on the scales when not dequantizing - return self._update_weight_conversions_mxfp8(weight_conversions) return weight_conversions + self.get_weight_conversions() updated: list = [] diff --git a/tests/test_tensor_parallel_mixin.py b/tests/test_tensor_parallel_mixin.py index 547bce7dacc4..a40649ad883c 100644 --- a/tests/test_tensor_parallel_mixin.py +++ b/tests/test_tensor_parallel_mixin.py @@ -454,54 +454,38 @@ def _get_tp_model_class(self): return self.model_tester.causal_lm_class return self.all_model_classes[0] - def _skip_if_not_supported(self): - """Check and skip test if TP is not supported for this model/environment.""" - if not is_torch_greater_or_equal("2.9"): - self.skipTest("Tensor parallel tests require torch >= 2.9") - - if torch.cuda.is_available() or torch.xpu.is_available(): - self.skipTest("Tensor parallel mixin tests are CPU-only and should not run on GPU or XPU machines") - - if os.cpu_count() < self.tensor_parallel_size: - self.skipTest( - f"Tensor parallel tests require at least {self.tensor_parallel_size} CPUs, " - f"but only {os.cpu_count()} available" + def _skip_if_not_supported(self, expert_parallel: bool = False): + """Check and skip the test if tensor/expert parallel is not supported for this model/environment.""" + parallelism = "Expert" if expert_parallel else "Tensor" + # An EP-capable MoE (@use_experts_implementation) must ship an ep_plan; assert before any + # skip so a plan-less model fails even where the parallel test can't run (GPU, old torch). + if expert_parallel and self._get_tp_model_class()._can_set_experts_implementation(): + self.assertTrue( + self._has_ep_plan(), + "Model supports a switchable experts implementation (@use_experts_implementation) but defines no " + "base_model_ep_plan; add an expert-parallel plan to its config so the EP path is covered.", ) - if not hasattr(self.model_tester, "causal_lm_class") or self.model_tester.causal_lm_class is None: - self.skipTest("Model tester does not have causal_lm_class (not using CausalLMModelTester)") - - if not self._has_tp_plan(): - self.skipTest("Model does not have a tensor parallel plan (base_model_tp_plan)") - - # # Skip encoder-decoder models (TP not supported) - # if getattr(self, "is_encoder_decoder", False): - # self.skipTest("TP tests not supported for encoder-decoder models") - - # # Skip VLM models for now - # config = self.model_tester.get_config() - # if hasattr(config, "vision_config") and config.vision_config is not None: - # self.skipTest("VLM models are not yet supported in TP tests") - - def _skip_if_ep_not_supported(self): - """Check and skip test if EP is not supported for this model/environment.""" if not is_torch_greater_or_equal("2.9"): - self.skipTest("Expert parallel tests require torch >= 2.9") + self.skipTest(f"{parallelism} parallel tests require torch >= 2.9") if torch.cuda.is_available() or torch.xpu.is_available(): - self.skipTest("Expert parallel mixin tests are CPU-only and should not run on GPU or XPU machines") + self.skipTest(f"{parallelism} parallel mixin tests are CPU-only and should not run on GPU or XPU machines") if os.cpu_count() < self.tensor_parallel_size: self.skipTest( - f"Expert parallel tests require at least {self.tensor_parallel_size} CPUs, " + f"{parallelism} parallel tests require at least {self.tensor_parallel_size} CPUs, " f"but only {os.cpu_count()} available" ) if not hasattr(self.model_tester, "causal_lm_class") or self.model_tester.causal_lm_class is None: self.skipTest("Model tester does not have causal_lm_class (not using CausalLMModelTester)") - if not self._has_ep_plan(): - self.skipTest("Model does not have an expert parallel plan (base_model_ep_plan)") + if expert_parallel: + if not self._has_ep_plan(): + self.skipTest("Model does not have an expert parallel plan (base_model_ep_plan)") + elif not self._has_tp_plan(): + self.skipTest("Model does not have a tensor parallel plan (base_model_tp_plan)") @is_tensor_parallel_test def test_tp_forward(self): @@ -577,7 +561,7 @@ def test_tp_generation_quantized(self): @is_tensor_parallel_test def test_ep_forward(self): - self._skip_if_ep_not_supported() + self._skip_if_not_supported(expert_parallel=True) config = self.model_tester.get_config() model_class = self._get_tp_model_class() @@ -593,7 +577,7 @@ def test_ep_forward(self): @is_tensor_parallel_test def test_ep_backward(self): - self._skip_if_ep_not_supported() + self._skip_if_not_supported(expert_parallel=True) config = self.model_tester.get_config() model_class = self._get_tp_model_class()