diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 61190f7837e5..d87a52b6dbb7 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -501,6 +501,8 @@ title: Apertus - local: model_doc/arcee title: Arcee + - local: model_doc/bailing2_5_moe + title: BailingMoeV2_5 - local: model_doc/bamba title: Bamba - local: model_doc/bart diff --git a/docs/source/en/model_doc/bailing2_5_moe.md b/docs/source/en/model_doc/bailing2_5_moe.md new file mode 100644 index 000000000000..50d8497e550e --- /dev/null +++ b/docs/source/en/model_doc/bailing2_5_moe.md @@ -0,0 +1,72 @@ + +*This model was contributed to Hugging Face Transformers on 2026-06-23.* + +# BailingMoeV2_5 + +## Overview + +The BailingMoeV2_5 model (Ling/Ring 2.6 series, e.g. Ling-2.6-flash) was proposed by [InclusionAI](https://huggingface.co/inclusionAI). It is based on a hybrid linear attention architecture, combining Multi-head Latent Attention (MLA), Lightning Linear Attention, and Mixture of Experts (MoE). 
+ +Key architectural features: +- **Hybrid Attention**: Uses a 1:7 ratio of MLA to Lightning Linear Attention layers, achieving near-linear computational complexity +- **Multi-head Latent Attention (MLA)**: Similar to DeepSeek-V3, with a KV cache compressed via low-rank projections +- **Lightning Linear Attention**: Based on SimpleGLA (Simple Gated Linear Attention) from the flash-linear-attention library +- **Mixture of Experts**: 256 routed experts with 8 active per token, plus shared experts + +### Usage tips + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch + +model = AutoModelForCausalLM.from_pretrained( + "inclusionAI/Ling-2.6-flash-base", + device_map="auto", + dtype=torch.bfloat16, +) +tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-2.6-flash-base") + +inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device) +outputs = model.generate(**inputs, max_new_tokens=50) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +For optimal performance with the linear attention layers, install the [flash-linear-attention](https://github.com/fla-org/flash-linear-attention) library. Without it, the model falls back to a pure PyTorch implementation. 
+ +## BailingMoeV2_5Config + +[[autodoc]] BailingMoeV2_5Config + +## BailingMoeV2_5Model + +[[autodoc]] BailingMoeV2_5Model + - forward + +## BailingMoeV2_5ForCausalLM + +[[autodoc]] BailingMoeV2_5ForCausalLM + - forward + +## BailingMoeV2_5ForSequenceClassification + +[[autodoc]] BailingMoeV2_5ForSequenceClassification + - forward + +## BailingMoeV2_5ForTokenClassification + +[[autodoc]] BailingMoeV2_5ForTokenClassification + - forward diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index a0cff222273a..28f8110f9594 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -868,6 +868,31 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"\.self_attn\.norm_q\.", target_patterns=".self_attn.q_norm."), WeightRenaming(source_patterns=r"\.self_attn\.norm_k\.", target_patterns=".self_attn.k_norm."), ], + "bailing2_5_moe": [ + # Embedding rename. + WeightRenaming(r"word_embeddings", "embed_tokens"), + # NOTE: full-attention (MLA) layer indices (where (i + 1) % layer_group_size == 0) + # are injected dynamically in `extract_weight_conversions_for_model` based on the + # model config, so the mapping works for any num_hidden_layers / layer_group_size. + WeightRenaming(r"\.attention\.", ".linear_attn."), + WeightRenaming(r"\.dense\.weight", ".o_proj.weight"), + # MoE router bias rename. + WeightRenaming(r"mlp\.gate\.expert_bias", "mlp.gate.e_score_correction_bias"), + # Pack per-expert gate_proj and up_proj into a single 3D tensor. 
+ WeightConverter( + source_patterns=[ + "mlp.experts.*.gate_proj.weight", + "mlp.experts.*.up_proj.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="mlp.experts.*.down_proj.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], "phimoe": [ WeightRenaming(".block_sparse_moe.", ".mlp."), WeightRenaming(".gate.weight", ".router.weight"), @@ -1516,6 +1541,20 @@ def extract_weight_conversions_for_model( conversions = get_checkpoint_conversion_mapping(class_name) if conversions is None and model_type: conversions = get_checkpoint_conversion_mapping(model_type) + + if model_type == "bailing2_5_moe" and conversions is not None: + # Inject `attention -> self_attn` renames for full-attention layer indices, + # derived from the model config rather than hardcoded. + num_hidden_layers = getattr(model.config, "num_hidden_layers", 0) + layer_group_size = getattr(model.config, "layer_group_size", 0) or 0 + if layer_group_size > 0: + full_attn_layers = [i for i in range(num_hidden_layers) if (i + 1) % layer_group_size == 0] + self_attn_renames = [ + WeightRenaming(rf"layers\.{i}\.attention\.", f"layers.{i}.self_attn.") for i in full_attn_layers + ] + # These must run before the generic `.attention. -> .linear_attn.` rule. 
+ conversions = self_attn_renames + conversions + return conversions diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 025c2ac7c0e1..eef556a10ae3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -31,6 +31,7 @@ from .auto import * from .autoformer import * from .aya_vision import * + from .bailing2_5_moe import * from .bamba import * from .bark import * from .bart import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index fe1212f230c1..58b6e4f8c13a 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -43,6 +43,7 @@ ("audioflamingo3_encoder", "AudioFlamingo3EncoderConfig"), ("autoformer", "AutoformerConfig"), ("aya_vision", "AyaVisionConfig"), + ("bailing2_5_moe", "BailingMoeV2_5Config"), ("bamba", "BambaConfig"), ("bark", "BarkConfig"), ("bart", "BartConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 9dfc42d07cd0..0a4babad192a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -56,6 +56,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), + ("bailing2_5_moe", "BailingMoeV2_5Model"), ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), @@ -647,6 +648,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("apertus", "ApertusForCausalLM"), ("arcee", "ArceeForCausalLM"), ("aria_text", "AriaTextForCausalLM"), + ("bailing2_5_moe", "BailingMoeV2_5ForCausalLM"), ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), @@ -1301,6 +1303,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): # Model for Sequence 
Classification mapping ("albert", "AlbertForSequenceClassification"), ("arcee", "ArceeForSequenceClassification"), + ("bailing2_5_moe", "BailingMoeV2_5ForSequenceClassification"), ("bart", "BartForSequenceClassification"), ("bert", "BertForSequenceClassification"), ("big_bird", "BigBirdForSequenceClassification"), @@ -1533,6 +1536,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("albert", "AlbertForTokenClassification"), ("apertus", "ApertusForTokenClassification"), ("arcee", "ArceeForTokenClassification"), + ("bailing2_5_moe", "BailingMoeV2_5ForTokenClassification"), ("bert", "BertForTokenClassification"), ("big_bird", "BigBirdForTokenClassification"), ("biogpt", "BioGptForTokenClassification"), diff --git a/src/transformers/models/bailing2_5_moe/__init__.py b/src/transformers/models/bailing2_5_moe/__init__.py new file mode 100644 index 000000000000..e5b1b784b514 --- /dev/null +++ b/src/transformers/models/bailing2_5_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_bailing2_5_moe import * + from .modeling_bailing2_5_moe import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/bailing2_5_moe/configuration_bailing2_5_moe.py b/src/transformers/models/bailing2_5_moe/configuration_bailing2_5_moe.py new file mode 100644 index 000000000000..7990f1294e0f --- /dev/null +++ b/src/transformers/models/bailing2_5_moe/configuration_bailing2_5_moe.py @@ -0,0 +1,162 @@ +# Copyright 2025 InclusionAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BailingMoeV2_5 model configuration""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="inclusionAI/Ling-2.6-flash-base") +@strict +class BailingMoeV2_5Config(PreTrainedConfig): + r""" + layer_group_size (`int`, *optional*, defaults to 8): + Controls the hybrid layer pattern. Every `layer_group_size`-th layer uses full MLA attention, + while the rest use lightning linear attention. 
+ n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts in group-limited-greedy routing. + first_k_dense_replace (`int`, *optional*, defaults to 4): + Number of initial dense layers before switching to MoE. + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. + group_norm_size (`int`, *optional*, defaults to 8): + Group size for group RMS normalization in linear attention layers. + num_kv_heads_for_linear_attn (`int`, *optional*, defaults to 64): + Number of key-value heads used in linear attention layers. + linear_silu (`bool`, *optional*, defaults to `False`): + Whether to apply SiLU activation on the gate in linear attention. + moe_shared_expert_intermediate_size (`int`, *optional*, defaults to 2048): + Intermediate size of the shared expert in MoE layers. + topk_method (`str`, *optional*, defaults to `"noaux_tc"`): + Method for selecting top-k experts in the MoE layer. + scoring_func (`str`, *optional*, defaults to `"sigmoid"`): + Scoring function for the router in the MoE layer. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Fraction of the head dimension to apply rotary position embeddings in linear attention layers. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + Coefficient for the auxiliary load balancing loss from the router. 
+ + Example: + + ```python + >>> from transformers import BailingMoeV2_5Model, BailingMoeV2_5Config + + >>> # Initializing a BailingMoeV2_5 style configuration + >>> configuration = BailingMoeV2_5Config() + >>> model = BailingMoeV2_5Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "bailing2_5_moe" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_local_experts": "num_experts", + } + + vocab_size: int = 157184 + hidden_size: int = 8192 + intermediate_size: int = 18432 + moe_intermediate_size: int = 2048 + moe_shared_expert_intermediate_size: int = 2048 + num_hidden_layers: int = 80 + num_attention_heads: int = 64 + num_key_value_heads: int | None = 64 + num_experts: int = 256 + num_shared_experts: int = 1 + num_experts_per_tok: int | None = 8 + routed_scaling_factor: float = 2.5 + kv_lora_rank: int = 512 + q_lora_rank: int | None = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int | None = 128 + qk_nope_head_dim: int = 128 + n_group: int | None = 8 + topk_group: int | None = 4 + topk_method: str = "noaux_tc" + scoring_func: str = "sigmoid" + first_k_dense_replace: int | None = 4 + norm_topk_prob: bool | None = True + layer_group_size: int = 8 + group_norm_size: int = 8 + num_kv_heads_for_linear_attn: int = 64 + linear_silu: bool = False + hidden_act: str = "silu" + 
max_position_embeddings: int = 131072 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = 156892 + bos_token_id: int | None = None + eos_token_id: int | list[int] | None = 156892 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + rope_interleave: bool | None = True + partial_rotary_factor: float = 0.5 + attention_bias: bool = False + attention_dropout: float | int | None = 0.0 + use_qk_norm: bool = True + output_router_logits: bool = False + router_aux_loss_coef: float = 0.001 + layer_types: list[str] | None = None + + def __post_init__(self, **kwargs): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.head_dim = self.qk_rope_head_dim + + if self.layer_types is None: + self.layer_types = [ + "full_attention" if (i + 1) % self.layer_group_size == 0 else "linear_attention" + for i in range(self.num_hidden_layers) + ] + + super().__post_init__(**kwargs) + + def convert_rope_params_to_dict(self, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["BailingMoeV2_5Config"] diff --git a/src/transformers/models/bailing2_5_moe/modeling_bailing2_5_moe.py b/src/transformers/models/bailing2_5_moe/modeling_bailing2_5_moe.py new file mode 100644 index 000000000000..fd125125fd7c --- /dev/null +++ b/src/transformers/models/bailing2_5_moe/modeling_bailing2_5_moe.py @@ -0,0 +1,1255 @@ +# 
🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/bailing2_5_moe/modular_bailing2_5_moe.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_bailing2_5_moe.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 InclusionAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Callable + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.import_utils import is_flash_linear_attention_available +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_bailing2_5_moe import BailingMoeV2_5Config + + +if is_flash_linear_attention_available(): + from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla +else: + chunk_simple_gla, fused_recurrent_simple_gla = None, None + +logger = logging.get_logger(__name__) + + +@use_kernel_forward_from_hub("RMSNorm") +class BailingMoeV2_5RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + BailingMoeV2_5RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * 
torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class BailingMoeV2_5RotaryEmbedding(nn.Module): + """RoPE for MLA layers — uses full qk_rope_head_dim, interleaved application.""" + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: BailingMoeV2_5Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: BailingMoeV2_5Config | None = None, + device: torch.device | None = None, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
+ """ + base = config.rope_parameters["rope_theta"] + dim = config.qk_rope_head_dim + + attention_factor = 1.0 + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BailingMoeV2_5LinearRotaryEmbedding(nn.Module): + """RoPE for linear attention layers — uses partial rotary factor on linear attention head dim.""" + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: BailingMoeV2_5Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: 
BailingMoeV2_5Config | None = None, + device: torch.device | None = None, + seq_len: int | None = None, + ) -> tuple[torch.Tensor, float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + linear_head_dim = config.hidden_size // config.num_kv_heads_for_linear_attn + dim = int(linear_head_dim * config.partial_rotary_factor) + + attention_factor = 1.0 + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class BailingMoeV2_5MLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class BailingMoeV2_5TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.num_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + return router_logits + + +@use_experts_implementation +class 
BailingMoeV2_5Experts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class BailingMoeV2_5MoE(nn.Module): + """ + A mixed expert module containing shared experts. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = BailingMoeV2_5Experts(config) + self.gate = BailingMoeV2_5TopkRouter(config) + self.shared_experts = BailingMoeV2_5MLP( + config=config, intermediate_size=config.moe_shared_expert_intermediate_size * config.num_shared_experts + ) + self.n_routed_experts = config.num_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits): + router_logits = router_logits.sigmoid() + router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + group_scores = ( + router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), float("-inf")) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = 
self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +class BailingMoeV2_5GroupRMSNorm(nn.Module): + """Group-wise RMS normalization for linear attention output.""" + + def __init__(self, hidden_size, group_norm_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_norm_size = group_norm_size + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + orig_shape = hidden_states.shape + # Reshape to groups: (..., group_norm_size, hidden_size // group_norm_size) + hidden_states = hidden_states.view(*orig_shape[:-1], self.group_norm_size, -1) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.view(orig_shape) + return (self.weight * hidden_states).to(input_dtype) + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +is_fast_path_available = all((chunk_simple_gla, fused_recurrent_simple_gla)) + + +def _build_slope_tensor(num_heads: int) -> torch.Tensor: + """Build ALiBi-style slope tensor for lightning linear attention decay.""" + + def _get_interleave(n: int) -> list[float]: + def _get_interleave_power_of_2(n: int) -> 
def torch_chunk_simple_gla(
    query,
    key,
    value,
    g,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
):
    """Chunked pure-PyTorch simple gated linear attention, used when the fla kernels are unavailable.

    Inputs are laid out time-major per head: `query`/`key` are `[batch, seq, heads, head_dim]`, `value` is
    `[batch, seq, heads, v_dim]` and `g` (the per-step log decay) is `[batch, seq, heads]`. The computation runs
    in float32; the output is cast back to the dtype of `query`.

    Args:
        query, key, value (`torch.Tensor`): Attention inputs as described above. `query` is scaled by
            `head_dim ** -0.5` internally.
        g (`torch.Tensor`): Log-space decay applied to the running state at every step.
        chunk_size (`int`, *optional*, defaults to 64): Number of time steps processed per chunk.
        initial_state (`torch.Tensor`, *optional*): Starting state of shape `[batch, heads, head_dim, v_dim]`;
            zeros when omitted.
        output_final_state (`bool`, *optional*, defaults to `False`): Whether to return the state after the
            last step.

    Returns:
        `tuple`: the output of shape `[batch, seq, heads, v_dim]` and the final float32 state (or `None` when
        `output_final_state` is false).
    """
    out_dtype = query.dtype

    def _heads_first(tensor):
        # [batch, seq, heads, ...] -> [batch, heads, seq, ...] in float32
        return tensor.transpose(1, 2).contiguous().to(torch.float32)

    query, key, value = _heads_first(query), _heads_first(key), _heads_first(value)
    g = _heads_first(g)

    batch_size, num_heads, sequence_length, head_dim = key.shape
    v_dim = value.shape[-1]

    # Right-pad the time axis to a whole number of chunks. Padded keys/values are zero and padded log-decays
    # are zero, so padding neither contributes to nor decays the running state.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    if pad_size > 0:
        time_pad = (0, 0, 0, pad_size)
        query = F.pad(query, time_pad)
        key = F.pad(key, time_pad)
        value = F.pad(value, time_pad)
        g = F.pad(g, (0, pad_size))
    total_len = sequence_length + pad_size
    num_chunks = total_len // chunk_size

    query = query * head_dim**-0.5

    chunked = (batch_size, num_heads, num_chunks, chunk_size)
    query = query.view(*chunked, head_dim)
    key = key.view(*chunked, head_dim)
    value = value.view(*chunked, v_dim)
    # Running sum of log-decays inside each chunk
    log_decay = g.view(*chunked).cumsum(dim=-1)

    if initial_state is None:
        state = torch.zeros(batch_size, num_heads, head_dim, v_dim, device=query.device, dtype=torch.float32)
    else:
        state = initial_state.to(torch.float32)
    output = torch.zeros(*chunked, v_dim, device=query.device, dtype=torch.float32)

    for idx in range(num_chunks):
        q_blk = query[:, :, idx]  # [B, H, C, D]
        k_blk = key[:, :, idx]
        v_blk = value[:, :, idx]
        decay_blk = log_decay[:, :, idx]  # [B, H, C]

        # Contribution of positions inside this chunk. The differences are masked to the lower triangle
        # *before* exponentiating, then the exponentiated matrix is masked again to zero the upper part.
        pairwise = (decay_blk.unsqueeze(-1) - decay_blk.unsqueeze(-2)).tril().exp()
        scores = (q_blk @ k_blk.transpose(-1, -2)) * pairwise.tril()
        within_chunk = scores @ v_blk

        # Contribution of everything accumulated before this chunk
        from_state = (q_blk * decay_blk.unsqueeze(-1).exp()) @ state

        output[:, :, idx] = within_chunk + from_state

        # Carry the state to the end of the chunk and add this chunk's key/value outer products
        last = decay_blk[:, :, -1]
        to_chunk_end = (last.unsqueeze(-1) - decay_blk).exp()  # [B, H, C]
        carried = state * last.unsqueeze(-1).unsqueeze(-1).exp()
        state = carried + (k_blk * to_chunk_end.unsqueeze(-1)).transpose(-1, -2) @ v_blk

    # Drop the padding and return to the [batch, seq, heads, v_dim] layout
    output = output.view(batch_size, num_heads, total_len, v_dim)[:, :, :sequence_length]
    output = output.transpose(1, 2).contiguous().to(out_dtype)

    return output, (state if output_final_state else None)
class BailingMoeV2_5LightningAttention(nn.Module):
    """Lightning Linear Attention using SimpleGLA (Simple Gated Linear Attention) from the fla library.

    The layer projects the input to fused query/key/value heads, optionally RMS-normalizes queries and keys per
    head, applies partial rotary embeddings, runs a simple gated linear attention kernel with a fixed per-head
    decay (`slope`), group-normalizes the result, gates it with a projection of the input and finally applies
    an output projection.

    When the fla kernels are not importable, `chunk_simple_gla` / `fused_recurrent_simple_gla` are falsy and
    the pure-PyTorch fallbacks defined in this module are used instead.
    """

    def __init__(self, config: BailingMoeV2_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_kv_heads_for_linear_attn
        self.head_dim = config.hidden_size // self.num_heads

        self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
        self.g_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)

        self.g_norm = BailingMoeV2_5GroupRMSNorm(config.hidden_size, config.group_norm_size, eps=config.rms_norm_eps)

        # The per-head q/k norms only exist when enabled in the config; `forward` checks the same flag.
        if config.use_qk_norm:
            self.query_layernorm = BailingMoeV2_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = BailingMoeV2_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Build ALiBi-style slopes for decay, scaled by layer position
        slopes = _build_slope_tensor(self.num_heads)
        # NOTE(review): the denominator is `num_hidden_layers - 1`, so a single-layer config raises
        # ZeroDivisionError here. The same expression is repeated in `_init_weights`; keep both in sync.
        layer_scale = 1 - (layer_idx - 1) / (config.num_hidden_layers - 1) + 1e-5
        # Stored negated: the value is passed to the kernels as the log-space decay `g`.
        # Non-persistent, so it is not written to / read from checkpoints.
        self.register_buffer("slope", (-slopes * layer_scale).to(torch.float32), persistent=False)

        # Prefer the fla kernels; fall back to the PyTorch implementations when they are `None`.
        self.chunk_simple_gla = chunk_simple_gla or torch_chunk_simple_gla
        self.recurrent_simple_gla = fused_recurrent_simple_gla or torch_recurrent_simple_gla

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path for BailingMoeV2_5LightningAttention is not available because flash-linear-attention "
                "is not installed. Falling back to pure PyTorch implementation. "
                "To install, see https://github.com/fla-org/flash-linear-attention#installation"
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run linear attention over `hidden_states` and return a tensor of the same shape.

        Args:
            hidden_states (`torch.Tensor`): Input of shape `[bsz, q_len, hidden_size]`.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`): `(cos, sin)` used for the partial RoPE.
            cache_params (`Cache`, *optional*): Cache holding this layer's recurrent state. When given, the
                final state of this call is written back to it.
            attention_mask (`torch.Tensor`, *optional*): Mask forwarded to `apply_mask_to_padding_states` to
                zero padded positions before any projection.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        bsz, q_len, _ = hidden_states.shape

        # The recurrent (single-step) kernel is only used for one-token decoding on top of an existing state.
        use_precomputed_states = (
            cache_params is not None and cache_params.has_previous_state(self.layer_idx) and q_len == 1
        )

        # Fused QKV projection
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, 3, self.num_heads, self.head_dim)
        query_states = qkv[:, :, 0]
        key_states = qkv[:, :, 1]
        value_states = qkv[:, :, 2]

        # Apply QK norm per head before RoPE (matches training-time behaviour)
        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        # Apply partial RoPE
        cos, sin = position_embeddings
        query_states, key_states = _apply_rotary_pos_emb_linear(query_states, key_states, cos, sin)

        # Gate projection
        g_proj = self.g_proj(hidden_states)

        # Compute decay from slopes: the same per-head value is broadcast over batch and time
        g = self.slope[None, None, :].expand(bsz, q_len, self.num_heads)

        if use_precomputed_states:
            recurrent_state = cache_params.layers[self.layer_idx].recurrent_states
            attn_output, last_state = self.recurrent_simple_gla(
                query_states,
                key_states,
                value_states,
                g=g,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
            )
        else:
            # NOTE(review): this branch always starts from `initial_state=None`, including when the cache
            # already holds a state but `q_len != 1`; any previously cached state is not used in that case.
            attn_output, last_state = self.chunk_simple_gla(
                query_states,
                key_states,
                value_states,
                g=g,
                initial_state=None,
                output_final_state=cache_params is not None,
            )

        if cache_params is not None:
            # For models without conv1d, we need to ensure dtype/device are set
            # on the cache layer before updating recurrent state
            layer = cache_params.layers[self.layer_idx]
            if not hasattr(layer, "dtype") or layer.dtype is None:
                layer.dtype = last_state.dtype
                layer.device = last_state.device
            cache_params.update_recurrent_state(last_state, self.layer_idx)
            # SimpleGLA has no conv1d, so update_conv_state is never called.
            # We manually set has_previous_state so decode uses the recurrent path.
            layer.has_previous_state = True

        # Reshape from [bsz, q_len, num_heads, head_dim] to [bsz*q_len, hidden_size]
        attn_output = attn_output.reshape(bsz * q_len, self.hidden_size)
        attn_output = self.g_norm(attn_output)
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        # Output gate: SiLU or sigmoid of the gate projection, selected by `config.linear_silu`
        if self.config.linear_silu:
            attn_output = attn_output * F.silu(g_proj)
        else:
            attn_output = attn_output * torch.sigmoid(g_proj)

        attn_output = self.o_proj(attn_output)
        return attn_output
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Same result as `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
        tiled = hidden_states[:, :, None, :, :].expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # `cos`/`sin` are `cat(freqs, freqs)`; the first half holds the per-pair angle. 
def yarn_get_mscale(scale=1, mscale=1):
    """Return the YaRN attention magnitude factor for a RoPE scaling factor.

    No correction (1.0) is applied when `scale <= 1`; otherwise the factor grows with the natural
    logarithm of `scale`, weighted by `mscale`.
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_parameters.get("rope_type", "default") != "default": + mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_parameters["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = 
class BailingMoeV2_5DecoderLayer(GradientCheckpointingLayer):
    """One pre-norm decoder block: a token mixer followed by an MLP, each wrapped in a residual connection.

    The token mixer is chosen by `config.layer_types[layer_idx]`: MLA (`self_attn`) for `"full_attention"`,
    Lightning linear attention (`linear_attn`) otherwise. The MLP is a MoE block from layer
    `config.first_k_dense_replace` onwards and a dense MLP before that.
    """

    def __init__(self, config: BailingMoeV2_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = config.layer_types[layer_idx]

        # Token mixer
        if self.layer_type == "full_attention":
            self.self_attn = BailingMoeV2_5Attention(config=config, layer_idx=layer_idx)
        else:
            self.linear_attn = BailingMoeV2_5LightningAttention(config=config, layer_idx=layer_idx)

        # Feed-forward: the first `first_k_dense_replace` layers stay dense
        mlp_cls = BailingMoeV2_5MLP if layer_idx < config.first_k_dense_replace else BailingMoeV2_5MoE
        self.mlp = mlp_cls(config)

        self.input_layernorm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        # Token-mixer sub-block. If `layer_type` matches neither branch, the normalized input is used as-is.
        mixed = self.input_layernorm(hidden_states)
        if self.layer_type == "linear_attention":
            # The linear-attention layer takes the cache as `cache_params` and receives no extra kwargs.
            mixed = self.linear_attn(
                hidden_states=mixed,
                position_embeddings=position_embeddings,
                cache_params=past_key_values,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            mixed, _ = self.self_attn(
                hidden_states=mixed,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                past_key_values=past_key_values,
                **kwargs,
            )
        after_mixer = hidden_states + mixed

        # MLP sub-block
        mlp_out = self.mlp(self.post_attention_layernorm(after_mixer))
        return after_mixer + mlp_out
std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BailingMoeV2_5LightningAttention): + # Reinitialize the slope buffer from config + slopes = _build_slope_tensor(module.num_heads) + layer_scale = 1 - (module.layer_idx - 1) / (self.config.num_hidden_layers - 1) + 1e-5 + init.copy_(module.slope, (-slopes * layer_scale).to(module.slope.dtype)) + + +class BailingMoeV2_5Model(BailingMoeV2_5PreTrainedModel): + def __init__(self, config: BailingMoeV2_5Config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.layers = nn.ModuleList( + [BailingMoeV2_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = BailingMoeV2_5RotaryEmbedding(config=config) + self.rotary_emb_linear = BailingMoeV2_5LinearRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], 
device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + linear_attn_mask = self._update_linear_attn_mask(attention_mask, past_key_values) + + hidden_states = inputs_embeds + position_embeddings_mla = self.rotary_emb(hidden_states, position_ids) + position_embeddings_linear = self.rotary_emb_linear(hidden_states, position_ids) + + for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): + if self.config.layer_types[i] == "linear_attention": + layer_mask = linear_attn_mask + layer_position_embeddings = position_embeddings_linear + else: + layer_mask = causal_mask + layer_position_embeddings = position_embeddings_mla + + hidden_states = decoder_layer( + hidden_states, + position_embeddings=layer_position_embeddings, + attention_mask=layer_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + def _update_linear_attn_mask(self, attention_mask, past_key_values): + """For linear attention, we only need a simple mask, not the full causal mask.""" + linear_attn_mask = attention_mask + if (past_key_values is not None and past_key_values.has_previous_state()) or ( + attention_mask is not None and torch.all(attention_mask == 1) + ): + linear_attn_mask = None + return linear_attn_mask + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. 
+ + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # 
Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class BailingMoeV2_5ForCausalLM(BailingMoeV2_5PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = BailingMoeV2_5Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_router_logits: bool | None = None, + logits_to_keep: int | 
torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, BailingMoeV2_5ForCausalLM + + >>> model = BailingMoeV2_5ForCausalLM.from_pretrained("inclusionAI/Ling-2.6-flash-base") + >>> tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-2.6-flash-base") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class BailingMoeV2_5ForSequenceClassification(GenericForSequenceClassification, BailingMoeV2_5PreTrainedModel): + pass + + +class BailingMoeV2_5ForTokenClassification(GenericForTokenClassification, BailingMoeV2_5PreTrainedModel): + pass + + +__all__ = [ + "BailingMoeV2_5PreTrainedModel", + "BailingMoeV2_5Model", + "BailingMoeV2_5ForCausalLM", + "BailingMoeV2_5ForSequenceClassification", + "BailingMoeV2_5ForTokenClassification", +] diff --git 
a/src/transformers/models/bailing2_5_moe/modular_bailing2_5_moe.py b/src/transformers/models/bailing2_5_moe/modular_bailing2_5_moe.py new file mode 100644 index 000000000000..accfcbe93b10 --- /dev/null +++ b/src/transformers/models/bailing2_5_moe/modular_bailing2_5_moe.py @@ -0,0 +1,682 @@ +# Copyright 2025 InclusionAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BailingMoeV2_5 model.""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
class BailingMoeV2_5RMSNorm(DeepseekV3RMSNorm):
    # Behaviour is inherited unchanged from the DeepSeek-V3 RMSNorm.
    pass


class BailingMoeV2_5RotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding whose frequency table spans `config.qk_rope_head_dim` channels (MLA layers)."""

    @staticmethod
    def compute_default_rope_parameters(
        config: BailingMoeV2_5Config | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple[torch.Tensor, float]:
        """Build the default inverse-frequency vector.

        Args:
            config: model config; `rope_parameters["rope_theta"]` and `qk_rope_head_dim` are read from it.
            device: device on which the frequency tensor is created.
            seq_len: accepted for signature compatibility; not used here.

        Returns:
            `(inv_freq, attention_factor)` where `attention_factor` is always `1.0`.
        """
        theta = config.rope_parameters["rope_theta"]
        rotary_dim = config.qk_rope_head_dim

        # Even channel indices 0, 2, ... < rotary_dim, normalised by rotary_dim, are the exponents of theta.
        channel_index = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (theta ** (channel_index / rotary_dim))
        return inv_freq, 1.0


class BailingMoeV2_5LinearRotaryEmbedding(LlamaRotaryEmbedding):
    """Rotary embedding for the linear-attention layers: covers only a fraction of the linear head dim."""

    @staticmethod
    def compute_default_rope_parameters(
        config: BailingMoeV2_5Config | None = None,
        device: torch.device | None = None,
        seq_len: int | None = None,
    ) -> tuple[torch.Tensor, float]:
        """Build the default inverse-frequency vector for partial rotary embedding.

        The rotary width is `int(head_dim * config.partial_rotary_factor)` with
        `head_dim = config.hidden_size // config.num_kv_heads_for_linear_attn`.

        Args:
            config: model config providing the fields named above and `rope_parameters["rope_theta"]`.
            device: device on which the frequency tensor is created.
            seq_len: accepted for signature compatibility; not used here.

        Returns:
            `(inv_freq, attention_factor)` where `attention_factor` is always `1.0`.
        """
        theta = config.rope_parameters["rope_theta"]
        head_width = config.hidden_size // config.num_kv_heads_for_linear_attn
        rotary_dim = int(head_width * config.partial_rotary_factor)

        channel_index = torch.arange(0, rotary_dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
        inv_freq = 1.0 / (theta ** (channel_index / rotary_dim))
        return inv_freq, 1.0


class BailingMoeV2_5MLP(DeepseekV3MLP):
    # Dense feed-forward block, inherited as-is.
    pass


class BailingMoeV2_5TopkRouter(DeepseekV3TopkRouter):
    def __init__(self, config):
        super().__init__(config)
        # This config names the routed-expert count `num_experts`.
        self.n_routed_experts = config.num_experts


class BailingMoeV2_5Experts(DeepseekV3NaiveMoe):
    # Routed experts container, inherited as-is.
    pass


class BailingMoeV2_5MoE(DeepseekV3MoE):
    def __init__(self, config):
        # The parent constructor is deliberately bypassed; only nn.Module is initialised.
        nn.Module.__init__(self)
        self.config = config

        # Sub-modules (registration order is preserved: experts, gate, shared_experts).
        self.experts = BailingMoeV2_5Experts(config)
        self.gate = BailingMoeV2_5TopkRouter(config)
        shared_width = config.moe_shared_expert_intermediate_size * config.num_shared_experts
        self.shared_experts = BailingMoeV2_5MLP(config=config, intermediate_size=shared_width)

        # Routing hyper-parameters copied from the config.
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.num_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
__init__(self, hidden_size, group_norm_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_norm_size = group_norm_size + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + orig_shape = hidden_states.shape + # Reshape to groups: (..., group_norm_size, hidden_size // group_norm_size) + hidden_states = hidden_states.view(*orig_shape[:-1], self.group_norm_size, -1) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.view(orig_shape) + return (self.weight * hidden_states).to(input_dtype) + + +def _build_slope_tensor(num_heads: int) -> torch.Tensor: + """Build ALiBi-style slope tensor for lightning linear attention decay.""" + + def _get_interleave(n: int) -> list[float]: + def _get_interleave_power_of_2(n: int) -> list[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = torch.tensor(_get_interleave(num_heads), dtype=torch.float32) + return slopes + + +def torch_chunk_simple_gla( + query, + key, + value, + g, + chunk_size=64, + initial_state=None, + output_final_state=False, +): + """Pure PyTorch fallback for chunk_simple_gla when fla is not available.""" + initial_dtype = query.dtype + query, key, value = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value)] + g = g.transpose(1, 2).contiguous().to(torch.float32) + + batch_size, num_heads, sequence_length, head_dim = key.shape + v_dim = value.shape[-1] + pad_size = (chunk_size - 
sequence_length % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_len = sequence_length + pad_size + + scale = head_dim**-0.5 + query = query * scale + + # Reshape into chunks + num_chunks = total_len // chunk_size + query = query.view(batch_size, num_heads, num_chunks, chunk_size, head_dim) + key = key.view(batch_size, num_heads, num_chunks, chunk_size, head_dim) + value = value.view(batch_size, num_heads, num_chunks, chunk_size, v_dim) + g = g.view(batch_size, num_heads, num_chunks, chunk_size) + + # Cumulative decay within each chunk + g_cumsum = g.cumsum(dim=-1) + + state = ( + torch.zeros(batch_size, num_heads, head_dim, v_dim, device=query.device, dtype=torch.float32) + if initial_state is None + else initial_state.to(torch.float32) + ) + output = torch.zeros( + batch_size, num_heads, num_chunks, chunk_size, v_dim, device=query.device, dtype=torch.float32 + ) + + for c in range(num_chunks): + q_c = query[:, :, c] # [B, H, C, D] + k_c = key[:, :, c] + v_c = value[:, :, c] + g_c = g_cumsum[:, :, c] # [B, H, C] + + # Intra-chunk attention with decay + decay_matrix = (g_c.unsqueeze(-1) - g_c.unsqueeze(-2)).tril().exp() + attn = (q_c @ k_c.transpose(-1, -2)) * decay_matrix.tril() + intra = attn @ v_c + + # Inter-chunk: query attends to state from previous chunks + inter = (q_c * g_c.unsqueeze(-1).exp()) @ state + + output[:, :, c] = intra + inter + + # Update state with this chunk's contributions + chunk_end_decay = g_c[:, :, -1].unsqueeze(-1).unsqueeze(-1) + per_step_decay = (g_c[:, :, -1].unsqueeze(-1) - g_c).exp() # [B, H, C] + state = state * chunk_end_decay.exp() + (k_c * per_step_decay.unsqueeze(-1)).transpose(-1, -2) @ v_c + + output = output.view(batch_size, num_heads, total_len, v_dim)[:, :, :sequence_length] + output = output.transpose(1, 2).contiguous().to(initial_dtype) + + if not 
output_final_state: + state = None + return output, state + + +def torch_recurrent_simple_gla( + query, + key, + value, + g, + initial_state=None, + output_final_state=False, +): + """Pure PyTorch fallback for fused_recurrent_simple_gla.""" + initial_dtype = query.dtype + query, key, value = [x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value)] + g = g.transpose(1, 2).contiguous().to(torch.float32) + + batch_size, num_heads, sequence_length, head_dim = key.shape + v_dim = value.shape[-1] + scale = head_dim**-0.5 + query = query * scale + + state = ( + torch.zeros(batch_size, num_heads, head_dim, v_dim, device=query.device, dtype=torch.float32) + if initial_state is None + else initial_state.to(torch.float32) + ) + output = torch.zeros(batch_size, num_heads, sequence_length, v_dim, device=query.device, dtype=torch.float32) + + for t in range(sequence_length): + decay = g[:, :, t].exp().unsqueeze(-1).unsqueeze(-1) # [B, H, 1, 1] + state = state * decay + key[:, :, t].unsqueeze(-1) * value[:, :, t].unsqueeze(-2) + output[:, :, t] = (query[:, :, t].unsqueeze(-1) * state).sum(dim=-2) + + if not output_final_state: + state = None + output = output.transpose(1, 2).contiguous().to(initial_dtype) + return output, state + + +def _apply_rotary_pos_emb_linear(q, k, cos, sin, unsqueeze_dim=2): + """Apply rotary position embedding with partial rotary support for linear attention. + Q/K are in [bsz, seq_len, n_heads, head_dim] format. 
def _apply_rotary_pos_emb_linear(q, k, cos, sin, unsqueeze_dim=2):
    """Apply (possibly partial) rotary position embeddings to linear-attention queries and keys.

    Only the first `cos.shape[-1]` channels of the last dimension are rotated; the remaining channels are
    passed through unchanged and concatenated back.

    Args:
        q, k: query/key tensors; the caller passes them as [bsz, seq_len, n_heads, head_dim].
        cos, sin: rotary tables; a singleton axis is inserted at `unsqueeze_dim` so they broadcast over heads.
        unsqueeze_dim: axis of `cos`/`sin` to unsqueeze (2 is the head axis of the layout above).

    Returns:
        Tuple `(q_embed, k_embed)` with the same shapes as `q` and `k`.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed


class BailingMoeV2_5LightningAttention(nn.Module):
    """Lightning Linear Attention using SimpleGLA (Simple Gated Linear Attention) from the fla library.

    Uses the fla kernels when they were importable and the pure PyTorch fallbacks otherwise.
    """

    def __init__(self, config: BailingMoeV2_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_kv_heads_for_linear_attn
        self.head_dim = config.hidden_size // self.num_heads

        self.query_key_value = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
        self.g_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)

        self.g_norm = BailingMoeV2_5GroupRMSNorm(config.hidden_size, config.group_norm_size, eps=config.rms_norm_eps)

        if config.use_qk_norm:
            self.query_layernorm = BailingMoeV2_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = BailingMoeV2_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Build ALiBi-style slopes for decay, scaled by layer position.
        # NOTE: this divides by `num_hidden_layers - 1`, so a single-layer config raises ZeroDivisionError.
        # The same formula is repeated in `BailingMoeV2_5PreTrainedModel._init_weights`; keep them in sync.
        slopes = _build_slope_tensor(self.num_heads)
        layer_scale = 1 - (layer_idx - 1) / (config.num_hidden_layers - 1) + 1e-5
        self.register_buffer("slope", (-slopes * layer_scale).to(torch.float32), persistent=False)

        self.chunk_simple_gla = chunk_simple_gla or torch_chunk_simple_gla
        self.recurrent_simple_gla = fused_recurrent_simple_gla or torch_recurrent_simple_gla

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path for BailingMoeV2_5LightningAttention is not available because flash-linear-attention "
                "is not installed. Falling back to pure PyTorch implementation. "
                "To install, see https://github.com/fla-org/flash-linear-attention#installation"
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        cache_params: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Run gated linear attention over `hidden_states`.

        Args:
            hidden_states: input of shape [bsz, q_len, hidden_size].
            position_embeddings: `(cos, sin)` rotary tables for the linear-attention layers.
            cache_params: optional cache; when given, the final recurrent state is stored for this layer.
            attention_mask: optional padding mask forwarded to `apply_mask_to_padding_states`.

        Returns:
            Tensor of shape [bsz, q_len, hidden_size].
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        bsz, q_len, _ = hidden_states.shape

        has_previous_state = cache_params is not None and cache_params.has_previous_state(self.layer_idx)
        # Single-token decoding uses the recurrent kernel; everything else goes through the chunked one.
        use_recurrent_kernel = has_previous_state and q_len == 1

        # Fused QKV projection
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, 3, self.num_heads, self.head_dim)
        query_states = qkv[:, :, 0]
        key_states = qkv[:, :, 1]
        value_states = qkv[:, :, 2]

        # Apply QK norm per head before RoPE
        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        # Apply partial RoPE
        cos, sin = position_embeddings
        query_states, key_states = _apply_rotary_pos_emb_linear(query_states, key_states, cos, sin)

        # Gate projection
        g_proj = self.g_proj(hidden_states)

        # Per-head log-decay (the slope buffer), broadcast over batch and time
        g = self.slope[None, None, :].expand(bsz, q_len, self.num_heads)

        # Resume from the cached recurrent state whenever one exists. The chunked path used to start from
        # a zero state unconditionally, which silently dropped the already-processed prefix when more than
        # one token was fed on top of a populated cache.
        initial_state = cache_params.layers[self.layer_idx].recurrent_states if has_previous_state else None
        kernel = self.recurrent_simple_gla if use_recurrent_kernel else self.chunk_simple_gla
        attn_output, last_state = kernel(
            query_states,
            key_states,
            value_states,
            g=g,
            initial_state=initial_state,
            output_final_state=cache_params is not None,
        )

        if cache_params is not None:
            # For models without conv1d, we need to ensure dtype/device are set
            # on the cache layer before updating recurrent state
            layer = cache_params.layers[self.layer_idx]
            if not hasattr(layer, "dtype") or layer.dtype is None:
                layer.dtype = last_state.dtype
                layer.device = last_state.device
            cache_params.update_recurrent_state(last_state, self.layer_idx)
            # SimpleGLA has no conv1d, so update_conv_state is never called.
            # We manually set has_previous_state so decode uses the recurrent path.
            layer.has_previous_state = True

        # Reshape from [bsz, q_len, num_heads, head_dim] to [bsz*q_len, hidden_size]
        attn_output = attn_output.reshape(bsz * q_len, self.hidden_size)
        attn_output = self.g_norm(attn_output)
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        # Output gate: SiLU or sigmoid of the gate projection, selected by the config
        if self.config.linear_silu:
            attn_output = attn_output * F.silu(g_proj)
        else:
            attn_output = attn_output * torch.sigmoid(g_proj)

        attn_output = self.o_proj(attn_output)
        return attn_output


class BailingMoeV2_5Attention(DeepseekV3Attention):
    """MLA (Multi-Latent Attention) inherited from DeepSeek V3."""

    pass
class BailingMoeV2_5DecoderLayer(GradientCheckpointingLayer):
    """One decoder block: a token mixer (MLA or lightning linear attention) followed by an MLP (dense or MoE)."""

    def __init__(self, config: BailingMoeV2_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The per-layer entry of `config.layer_types` selects the token mixer.
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "full_attention":
            self.self_attn = BailingMoeV2_5Attention(config=config, layer_idx=layer_idx)
        else:
            self.linear_attn = BailingMoeV2_5LightningAttention(config=config, layer_idx=layer_idx)

        # The first `first_k_dense_replace` layers keep a dense MLP; later layers use the MoE block.
        uses_moe = layer_idx >= config.first_k_dense_replace
        self.mlp = BailingMoeV2_5MoE(config) if uses_moe else BailingMoeV2_5MLP(config)

        self.input_layernorm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        # Pre-norm token mixing with a residual connection.
        mixed = self.input_layernorm(hidden_states)
        if self.layer_type == "linear_attention":
            mixed = self.linear_attn(
                hidden_states=mixed,
                position_embeddings=position_embeddings,
                cache_params=past_key_values,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            mixed, _ = self.self_attn(
                hidden_states=mixed,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                past_key_values=past_key_values,
                **kwargs,
            )
        after_mixer = hidden_states + mixed

        # Pre-norm MLP with a residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(after_mixer))
        return after_mixer + mlp_out


class BailingMoeV2_5PreTrainedModel(PreTrainedModel):
    config: BailingMoeV2_5Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BailingMoeV2_5DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_record_outputs = {
        "router_logits": OutputRecorder(BailingMoeV2_5TopkRouter, index=0),
        "hidden_states": BailingMoeV2_5DecoderLayer,
        "attentions": BailingMoeV2_5Attention,
    }
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the default initialisation, then handle the router, the experts and the slope buffer."""
        super()._init_weights(module)

        if isinstance(module, BailingMoeV2_5TopkRouter):
            init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            init.zeros_(module.e_score_correction_bias)
            return

        if isinstance(module, BailingMoeV2_5Experts):
            for expert_weight in (module.gate_up_proj, module.down_proj):
                init.normal_(expert_weight, mean=0.0, std=self.config.initializer_range)
            return

        if isinstance(module, BailingMoeV2_5LightningAttention):
            # Recompute the slope buffer with the same formula the attention constructor uses.
            depth_scale = 1 - (module.layer_idx - 1) / (self.config.num_hidden_layers - 1) + 1e-5
            head_slopes = _build_slope_tensor(module.num_heads)
            init.copy_(module.slope, (-head_slopes * depth_scale).to(module.slope.dtype))
class BailingMoeV2_5Model(BailingMoeV2_5PreTrainedModel):
    def __init__(self, config: BailingMoeV2_5Config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [BailingMoeV2_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = BailingMoeV2_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Two rotary tables: one for the MLA layers, one for the linear-attention layers.
        self.rotary_emb = BailingMoeV2_5RotaryEmbedding(config=config)
        self.rotary_emb_linear = BailingMoeV2_5LinearRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of the two input forms must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if past_key_values is None and use_cache:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue counting from the number of tokens already held by the cache.
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + offset
            position_ids = new_positions.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, past_key_values)

        hidden_states = inputs_embeds
        mla_rope = self.rotary_emb(hidden_states, position_ids)
        linear_rope = self.rotary_emb_linear(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for layer_idx, decoder_layer in enumerate(active_layers):
            # Each layer kind gets its own mask and rotary table.
            is_linear = self.config.layer_types[layer_idx] == "linear_attention"
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=linear_rope if is_linear else mla_rope,
                attention_mask=linear_attn_mask if is_linear else causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )

    def _update_linear_attn_mask(self, attention_mask, past_key_values):
        """Return the mask for linear-attention layers, or None when no masking is required.

        The 2D `attention_mask` is passed through as-is, except that it is dropped when the cache already
        holds a previous state or when the mask contains only ones.
        """
        if past_key_values is not None and past_key_values.has_previous_state():
            return None
        if attention_mask is not None and torch.all(attention_mask == 1):
            return None
        return attention_mask


class BailingMoeV2_5ForCausalLM(MixtralForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # This config names the routed-expert count `num_experts`.
        self.num_experts = config.num_experts

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BailingMoeV2_5ForCausalLM

        >>> model = BailingMoeV2_5ForCausalLM.from_pretrained("inclusionAI/Ling-2.6-flash-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("inclusionAI/Ling-2.6-flash-base")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Pure pass-through: the parent implementation does all the work.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )


class BailingMoeV2_5ForSequenceClassification(GenericForSequenceClassification, BailingMoeV2_5PreTrainedModel):
    # Behaviour comes entirely from the two base classes.
    pass


class BailingMoeV2_5ForTokenClassification(GenericForTokenClassification, BailingMoeV2_5PreTrainedModel):
    # Behaviour comes entirely from the two base classes.
    pass


__all__ = [
    "BailingMoeV2_5PreTrainedModel",
    "BailingMoeV2_5Model",
    "BailingMoeV2_5ForCausalLM",
    "BailingMoeV2_5ForSequenceClassification",
    "BailingMoeV2_5ForTokenClassification",
]
class BailingMoeV2_5ModelTester:
    """Builds a tiny BailingMoeV2_5 config plus random inputs for the test classes in this file."""

    if is_torch_available():
        causal_lm_class = BailingMoeV2_5ForCausalLM

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        intermediate_size=32,
        moe_intermediate_size=16,
        moe_shared_expert_intermediate_size=16,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_shared_experts=1,
        num_experts=8,
        routed_scaling_factor=2.5,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=4,
        v_head_dim=8,
        qk_nope_head_dim=4,
        n_group=2,
        topk_group=1,
        num_experts_per_tok=2,
        first_k_dense_replace=1,
        norm_topk_prob=True,
        layer_group_size=4,
        group_norm_size=2,
        num_kv_heads_for_linear_attn=4,
        linear_silu=False,
        hidden_act="silu",
        max_position_embeddings=512,
        initializer_range=0.02,
        attention_probs_dropout_prob=0.0,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent

        # Input shapes and which optional inputs/labels to generate
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Core transformer sizes
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.pad_token_id = pad_token_id

        # Mixture-of-experts settings
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
        self.num_shared_experts = num_shared_experts
        self.num_experts = num_experts
        self.routed_scaling_factor = routed_scaling_factor
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob

        # MLA settings
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim

        # Linear-attention settings
        self.layer_group_size = layer_group_size
        self.group_norm_size = group_norm_size
        self.num_kv_heads_for_linear_attn = num_kv_heads_for_linear_attn
        self.linear_silu = linear_silu

        # Label spaces
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)`.

        Entries whose `use_*` flag is off are None.
        """
        token_shape = [self.batch_size, self.seq_length]

        # The random tensors are drawn in a fixed order: ids, token types, then the three label tensors.
        input_ids = ids_tensor(token_shape, self.vocab_size)
        input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) if self.use_input_mask else None
        token_type_ids = ids_tensor(token_shape, self.type_vocab_size) if self.use_token_type_ids else None

        sequence_labels = token_labels = choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor(token_shape, self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        return (
            self.get_config(),
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        """Build a `BailingMoeV2_5Config` from the sizes stored on this tester."""
        return BailingMoeV2_5Config(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            moe_intermediate_size=self.moe_intermediate_size,
            moe_shared_expert_intermediate_size=self.moe_shared_expert_intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            num_shared_experts=self.num_shared_experts,
            num_experts=self.num_experts,
            routed_scaling_factor=self.routed_scaling_factor,
            kv_lora_rank=self.kv_lora_rank,
            q_lora_rank=self.q_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            qk_nope_head_dim=self.qk_nope_head_dim,
            n_group=self.n_group,
            topk_group=self.topk_group,
            num_experts_per_tok=self.num_experts_per_tok,
            first_k_dense_replace=self.first_k_dense_replace,
            norm_topk_prob=self.norm_topk_prob,
            layer_group_size=self.layer_group_size,
            group_norm_size=self.group_norm_size,
            num_kv_heads_for_linear_attn=self.num_kv_heads_for_linear_attn,
            linear_silu=self.linear_silu,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            use_cache=True,
            pad_token_id=self.pad_token_id,
            # The tester's dropout knob maps onto the config's `attention_dropout` field.
            attention_dropout=self.attention_probs_dropout_prob,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run the base model with and without a mask and check the shape of the last hidden state."""
        model = BailingMoeV2_5Model(config=config)
        model.to(torch_device)
        model.eval()

        # The masked call only has to run without error; the shape check uses the unmasked call.
        model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` in the form the common test mixins consume."""
        config, input_ids, _token_type_ids, input_mask, *_labels = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}
self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + config._attn_implementation = "eager" + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), sum(layer == "full_attention" for layer in config.layer_types)) + out_len = len(outputs) + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + self_attentions = outputs.attentions + + self.assertEqual(out_len + 1, len(outputs)) + self.assertEqual(len(self_attentions), sum(layer == "full_attention" for layer in config.layer_types)) + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("BailingMoeV2_5 is not compatible with assisted decoding due to hybrid cache") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("BailingMoeV2_5 is not compatible with assisted decoding due to hybrid cache") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("BailingMoeV2_5 is not compatible with assisted decoding due to 
hybrid cache") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so it is not compatible with the standard cache format") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so it is not compatible with the standard cache format") + def test_greedy_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="SDPA can't dispatch on flash due to unsupported head dims") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so beam search is not compatible with the standard cache format") + def test_beam_sample_generate(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so beam search is not compatible with the standard cache format") + def test_beam_search_generate(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so beam search is not compatible with the standard cache format") + def test_beam_sample_generate_dict_output(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so beam search is not compatible with the standard cache format") + def test_beam_search_generate_dict_output(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so it is not compatible with continue from past_key_values") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("BailingMoeV2_5's linear attention has no conv1d, so conv_states are None") + def test_past_key_values_format(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so inputs_embeds generation is not compatible with cache format") + def test_generate_from_inputs_embeds_0_greedy(self): + pass + + @unittest.skip("BailingMoeV2_5 uses MLA so inputs_embeds generation is not compatible with cache format") + def test_generate_from_inputs_embeds_1_beam(self): + pass + + @unittest.skip("The specific cache format cannot be instantiated from dp/ddp data.") + def test_multi_gpu_data_parallel_forward(self): + pass + + def 
test_bailing2_5_moe_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_labels) + model = BailingMoeV2_5ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1d590e43a9cd..5d54305dc0cc 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -827,6 +827,7 @@ def test_num_layers_is_small(self): "Gemma3nVision2TextModelTest": 4, # need to test KV shared layer for both types: `full_attention` and `sliding_attention` "BeitModelTest": 4, # BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers "ZambaModelTest": 5, # The minimum number to test beyond the initial ["mamba", "mamba", "hybrid"] in `ZambaConfig._layers_block_type` + "BailingMoeV2_5ModelTest": 4, # need at least 4 layers (layer_group_size=4) to test both full_attention and linear_attention } target_num_hidden_layers = exceptional_num_hidden_layers.get(type(self).__name__, 2) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 9b176987ccc6..dbbfae5eadb4 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -64,6 +64,9 @@ "Lfm2Config": ["full_attn_idxs"], "DiaConfig": ["delay_pattern"], "BambaConfig": ["attn_layer_indices"], + # layer_group_size builds `layer_types` in __post_init__ (and drives weight conversion); scoring_func/topk_method + # describe the router behavior the model hardcodes (sigmoid + noaux_tc), kept for checkpoint config compatibility. 
+ "BailingMoeV2_5Config": ["layer_group_size", "scoring_func", "topk_method"], "Dots1Config": ["max_window_layers"], "JambaConfig": ["attn_layer_offset", "attn_layer_period", "expert_layer_offset", "expert_layer_period"], "JetMoeConfig": ["output_router_logits"],