diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e8cbed059c1e..5b0946b32743 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -557,6 +557,8 @@ title: DeepSeek-V2 - local: model_doc/deepseek_v3 title: DeepSeek-V3 + - local: model_doc/deepseek_v4 + title: DeepSeek-V4 - local: model_doc/dialogpt title: DialoGPT - local: model_doc/diffllama diff --git a/docs/source/en/model_doc/deepseek_v4.md b/docs/source/en/model_doc/deepseek_v4.md new file mode 100644 index 000000000000..1ee30ea7f78f --- /dev/null +++ b/docs/source/en/model_doc/deepseek_v4.md @@ -0,0 +1,112 @@ + +*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-28.* + +# DeepSeek-V4 + +[DeepSeek-V4](https://huggingface.co/deepseek-ai) is the next-generation MoE language model from DeepSeek +([paper](https://huggingface.co/deepseek-ai/DeepSeek-V4-Flash/blob/main/DeepSeek_V4.pdf)). The architecture replaces +DeepSeek-V3's Multi-head Latent Attention (MLA) with a hybrid local + long-range design, swaps residual connections +for Manifold-Constrained Hyper-Connections (mHC), and bootstraps the first few MoE layers with a static +token-id → expert-id hash table. + +This implementation covers `DeepSeek-V4-Flash`, `DeepSeek-V4-Pro`, and their `-Base` pretrained siblings. All four +share the same architecture; they differ only in width / depth / expert count and weights. + +## Architecture (paper §2) + +### Hybrid attention (§2.3) + +Each decoder block is one of three attention types, dispatched by `config.layer_types[i]`: + +* **Sliding-window full attention** (`"sliding_attention"`): only the local window of `sliding_window` tokens, no + long-range branch. Matches V3's "Full Attention" style for the bootstrap layers. +* **Compressed Sparse Attention** (`"compressed_sparse_attention"`, **CSA** — paper §2.3.1): a low-compression + pool (`compress_rate_csa`, default `m=4`) with overlapping windows, plus a **Lightning Indexer** (eqs. 13–17) + that scores queries against the pool and gathers the top `index_topk` blocks per query before they reach core + attention. +* **Heavily Compressed Attention** (`"heavily_compressed_attention"`, **HCA** — paper §2.3.2): a high-compression + pool (`compress_rate_hca`, default `m'=128`) with non-overlapping windows. No indexer — every pooled entry + contributes to attention. + +All three types share the same backbone: + +* **Shared K=V Multi-Query Attention**: `num_key_value_heads = 1`; `kv_proj` produces a single KV head and the same + tensor is read as both key and value. +* **Partial RoPE** (interleaved-pair, paper §2.3.3 "Partial Rotary Positional Embedding") on the trailing + `qk_rope_head_dim = head_dim * partial_rotary_factor` channels of each head. The same rotation is applied with + position `-i` to the attention output's rope slice (eq. 26) so the contribution of each KV entry stays a function + of the *relative* distance to the query. +* **Per-head learnable attention sink** (eq. 27). +* **Grouped low-rank output projection** (§2.3.1 "Grouped Output Projection"): `o_groups` head-groups → `o_lora_rank` + per group → `hidden_size`, computed by [`DeepseekV4GroupedLinear`] (`o_a_proj`) followed by `o_b_proj`. Cuts the + per-token cost of the wide attention output without losing expressivity. 
+* **Shared sliding-window K=V branch** of size `sliding_window` ("Additional Branch of Sliding Window Attention", + §2.3.1) preserves local fine-grained dependencies; the long-range compressor's output is concatenated with this + branch's KVs before core attention. + +### Manifold-Constrained Hyper-Connections (§2.2) + +Residual connections are replaced by mHC (Xie et al., 2026): `hc_mult` parallel residual streams kept in shape +`[B, S, hc_mult, D]` throughout each block. Two [`DeepseekV4HyperConnection`] modules — `attn_hc` and `ffn_hc` — mix +streams in and out around the attention / MLP sublayers via a `(pre, post, comb)` triplet. The `comb` matrix is a +doubly-stochastic projection produced by `hc_sinkhorn_iters` Sinkhorn–Knopp iterations on the manifold, making +signal propagation non-expansive across deep stacks. A final [`DeepseekV4HyperHead`] collapses the `hc_mult` +streams down to a single sequence before the model norm. + +### MoE schedule (§2.1) + +Routing is configured per layer by `config.mlp_layer_types`, with values from `{"hash_moe", "moe"}`: + +* `"hash_moe"`: expert indices come from a frozen `tid2eid[input_ids]` lookup populated from the V4 checkpoint. + The learned gate `weight` still produces the per-expert scores that weight the selected experts; only + *which-experts* is static. Used for the first few bootstrap layers (default 3, override via legacy + `num_hash_layers`). +* `"moe"`: standard top-k routed MoE. The expert affinity uses **Sqrt(Softplus(·))** instead of V3's Sigmoid + ("we change the activation function that computes the affinity scores from Sigmoid(·) into Sqrt(Softplus(·))", + paper §2.1), and V3's `n_group` / `topk_group` constraint is dropped. The auxiliary-loss-free strategy + (DeepSeek's `noaux_tc`) is preserved via the `e_score_correction_bias` buffer that biases the top-k argmax + without flowing gradients. + +Routed experts use a **clamped SwiGLU** (`gate.clamp(max=swiglu_limit)`, `up.clamp(min=-swiglu_limit, max=swiglu_limit)`, +then `act_fn(gate) * up`) on top of the standard Mixtral `[num_experts, 2 * moe_intermediate_size, hidden_size]` +expert weight layout. A single shared expert (a plain SwiGLU MLP at `moe_intermediate_size` width) runs in parallel +on every token. + +### Cache layers + +Each non-sliding attention block needs to thread compressor / indexer state across forward calls. V4 ships two +cache layer types that auto-register with `LAYER_TYPE_CACHE_MAPPING`: + +* `DeepseekV4HCACache`: sliding-window K=V + HCA compressor buffer / pool / count (no overlap, no indexer). +* `DeepseekV4CSACache`: sliding-window K=V + CSA compressor (with overlap state) + parallel indexer + buffer / pool / count / overlap at `index_head_dim`. + +`DynamicCache(config=…)` builds the right cache layer per `config.layer_types[i]`. 
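+
+## Usage example
+
+A minimal generation sketch, not part of the upstream PR. The repo id is assumed from the variants listed above (`deepseek-ai/DeepSeek-V4-Flash`); the per-layer CSA / HCA cache layers described in the previous section are created automatically by `DynamicCache(config=…)` inside `generate`, so no cache arguments are needed.
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "deepseek-ai/DeepSeek-V4-Flash"  # assumed repo id, see the variants listed above
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto")
+
+inputs = tokenizer("Compressed sparse attention works by", return_tensors="pt").to(model.device)
+# generate() builds a DynamicCache from the config, which dispatches DeepseekV4CSACache /
+# DeepseekV4HCACache / DynamicSlidingWindowLayer per layer from config.layer_types.
+outputs = model.generate(**inputs, max_new_tokens=64)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```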
+ +## DeepseekV4Config + +[[autodoc]] DeepseekV4Config + +## DeepseekV4Model + +[[autodoc]] DeepseekV4Model + - forward + +## DeepseekV4ForCausalLM + +[[autodoc]] DeepseekV4ForCausalLM + - forward diff --git a/src/transformers/activations.py b/src/transformers/activations.py index a51ebca341d4..d2d1b362f79a 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -214,6 +214,13 @@ def forward(self, input): return squared +class SqrtSoftplusActivation(nn.Module): + """sqrt(softplus(x)) — the router scoring function used by DeepSeek V4.""" + + def forward(self, input): + return nn.functional.softplus(input).sqrt() + + class ClassInstantier(OrderedDict): def __getitem__(self, key): content = super().__getitem__(key) @@ -334,6 +341,7 @@ def forward(self, input: Tensor) -> Tensor: "relu6": nn.ReLU6, "sigmoid": nn.Sigmoid, "silu": SiLUActivation, + "sqrtsoftplus": SqrtSoftplusActivation, "swish": nn.SiLU, "tanh": nn.Tanh, "prelu": nn.PReLU, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 95a47ae39fdf..b80dfa7d2267 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,10 +23,31 @@ logger = logging.get_logger(__name__) +# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for +# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided +# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own +# cache-layer subclass and stop needing a model-specific ``Cache`` subclass. +# +# A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via +# ``CacheLayerMixin.__init_subclass__``. Each registered class must accept a +# ``PreTrainedConfig`` (the decoder text config) as the only positional argument. +LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {} + + class CacheLayerMixin(ABC): """Base, abstract class for a single layer's cache.""" is_compileable = False + # Subclasses can set ``layer_type`` to auto-register themselves in + # ``LAYER_TYPE_CACHE_MAPPING`` at import time (used by ``DynamicCache`` to dispatch + # per-layer cache classes from ``config.layer_types``). + layer_type: str | None = None + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + layer_type = cls.__dict__.get("layer_type", None) + if layer_type is not None: + LAYER_TYPE_CACHE_MAPPING[layer_type] = cls def __init__(self): self.keys: torch.Tensor | None = None @@ -93,6 +114,9 @@ class DynamicLayer(CacheLayerMixin): is_sliding = False + def __init__(self, config: PreTrainedConfig | None = None): + super().__init__() + def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: self.dtype, self.device = key_states.dtype, key_states.device self.keys = torch.tensor([], dtype=self.dtype, device=self.device) @@ -171,8 +195,14 @@ class DynamicSlidingWindowLayer(DynamicLayer): is_sliding = True - def __init__(self, sliding_window: int): + def __init__(self, config: PreTrainedConfig | None = None, sliding_window: int | None = None): super().__init__() + # Accept either a config (registry-style construction via LAYER_TYPE_CACHE_MAPPING) + # or a raw ``sliding_window`` int (legacy callers). 
+ if sliding_window is None: + if config is None: + raise ValueError("Either `config` or `sliding_window` must be provided.") + sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None) self.sliding_window = sliding_window self.cumulative_length = 0 self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long) @@ -732,6 +762,9 @@ def crop(self, max_length: int): class LinearAttentionLayer(LinearAttentionCacheLayerMixin): + def __init__(self, config: PreTrainedConfig | None = None): + super().__init__() + def lazy_initialization( self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None ) -> None: @@ -808,7 +841,7 @@ class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer): # The dynamic Attention part makes it non-compileable is_compileable = False - def __init__(self): + def __init__(self, config: PreTrainedConfig | None = None): DynamicLayer.__init__(self) LinearAttentionLayer.__init__(self) @@ -831,6 +864,29 @@ def reorder_cache(self, beam_idx: torch.LongTensor): DynamicLayer.reorder_cache(self, beam_idx) +# Pre-register the standard layer types (some classes are shared between multiple types, +# e.g. ``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and +# ``"chunked_attention"`` — those need an explicit map entry rather than the +# auto-registration via ``CacheLayerMixin.__init_subclass__``). +LAYER_TYPE_CACHE_MAPPING.update( + { + "full_attention": DynamicLayer, + # From a cache point of view, sliding and chunked are the same in how they should behave; + # only the mask differs. + "sliding_attention": DynamicSlidingWindowLayer, + "chunked_attention": DynamicSlidingWindowLayer, + # Linear-attention-shaped layers (mamba / conv / pure linear-attention / moe placeholders) + # don't grow per-token KV; they're tracked just so position bookkeeping stays consistent. + "mamba": LinearAttentionLayer, + "conv": LinearAttentionLayer, + "linear_attention": LinearAttentionLayer, + "moe": LinearAttentionLayer, + # Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state. + "hybrid": LinearAttentionAndFullAttentionLayer, + } +) + + class Cache: """ A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for @@ -1240,20 +1296,13 @@ def __init__( layer_types = layer_types[: -decoder_config.num_kv_shared_layers] for layer_type in layer_types: - # From a cache point of view, both sliding and chunked are the same in how they should behave and how many - # states they should return - only the mask changes to make them different at the end! - if layer_type in ("sliding_attention", "chunked_attention"): - layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window)) - # Note: we want moe layers to be LinearAttentionLayer, so that we can correctly grab sequence length etc from attention layers. 
- # Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc - # TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip - # the indices we don't need - elif layer_type in ("mamba", "conv", "linear_attention", "moe"): - layers.append(LinearAttentionLayer()) - elif layer_type == "hybrid": - layers.append(LinearAttentionAndFullAttentionLayer()) - else: - layers.append(DynamicLayer()) + # Dispatch through the registry — ``LAYER_TYPE_CACHE_MAPPING`` ships with the + # standard layer types pre-registered, and models with custom layer types + # (e.g. DeepSeek-V4's CSA / HCA) register their own classes there. Each class + # is instantiated with the decoder config so it can read whatever attributes + # it needs (sliding_window, compress_rate, ...). + cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer) + layers.append(cache_cls(decoder_config)) # In this case, use the passed data to already fill in the Cache if ddp_cache_data is not None: @@ -1353,7 +1402,7 @@ def __init__( layers = [] for layer_type in layer_types: - if layer_type == "sliding_attention": + if layer_type in ("sliding_attention", "compressed_sparse_attention", "heavily_compressed_attention"): layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window) elif layer_type == "chunked_attention": # From a cache point of view, both sliding and chunked are the same in how they should behave and how many diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 2dcdc5333f35..30377c75df6e 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -63,6 +63,8 @@ "full_attention", "sliding_attention", "chunked_attention", + "compressed_sparse_attention", # CSA, used in deepseek_v4 + "heavily_compressed_attention", # HCA, used in deepseek_v4 "linear_attention", # used in minimax "conv", # used in LFMv2 "mamba", diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index dadfeb4224ad..c6de330adc25 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -97,6 +97,183 @@ def _build_checkpoint_conversion_mapping(): "altclip": [ WeightRenaming(source_patterns=r"layer\.", target_patterns="layers."), ], + "deepseek_v4": [ + # Upstream checkpoint uses a flatter, V3-style namespace: ``attn`` / ``ffn`` + # instead of ``self_attn`` / ``mlp``, ``attn_norm`` / ``ffn_norm`` instead of + # ``input_layernorm`` / ``post_attention_layernorm``, ``hc_attn_*`` / ``hc_ffn_*`` + # for the Hyper-Connection params (we wrap them in ``attn_hc`` / ``ffn_hc`` + # submodules), ``embed`` / ``head`` / bare ``norm`` for the model head, and + # ``hc_head_*`` for the final HC collapse. The Indexer's compressor tree is + # nested under ``attn.indexer.compressor.*`` upstream but flattened onto the + # Indexer module here. FP8 scales arrive as ``.scale`` and need to become + # ``.weight_scale_inv`` to match :class:`FineGrainedFP8Linear`. + # + # Ordering matters for save round-tripping: :func:`revert_weight_conversion` + # reverses the order *and* each transform, so a structural prefix-only rule + # placed before a specific in-prefix rename would steal the reverse match + # and emit ``layers.X.attn.sinks`` instead of ``layers.X.attn.attn_sink``. 
+ # We split into two passes: structural prefix renames first (so they apply + # last on save / first on load), then specific in-prefix renames that + # operate on the already-prefixed keys. + # + # FP8 ``.scale`` → ``.weight_scale_inv`` rename lives in the FP8 quantizer's + # ``update_weight_conversions`` (only kicks in when FP8 dequant is active), + # so the V4 static mapping below stays free of FP8-only rules. + # ---- Pass 1: top-level + structural prefix renames ---- + WeightRenaming(source_patterns=r"^embed\.weight$", target_patterns="model.embed_tokens.weight"), + WeightRenaming(source_patterns=r"^head\.weight$", target_patterns="lm_head.weight"), + WeightRenaming(source_patterns=r"^norm\.weight$", target_patterns="model.norm.weight"), + WeightRenaming(source_patterns=r"^hc_head_fn$", target_patterns="model.hc_head.hc_fn"), + WeightRenaming(source_patterns=r"^hc_head_base$", target_patterns="model.hc_head.hc_base"), + WeightRenaming(source_patterns=r"^hc_head_scale$", target_patterns="model.hc_head.hc_scale"), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.attn_norm\.", + target_patterns=r"model.layers.\1.input_layernorm.", + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.ffn_norm\.", + target_patterns=r"model.layers.\1.post_attention_layernorm.", + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_attn_fn$", target_patterns=r"model.layers.\1.attn_hc.fn" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_attn_base$", target_patterns=r"model.layers.\1.attn_hc.base" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_attn_scale$", target_patterns=r"model.layers.\1.attn_hc.scale" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_ffn_fn$", target_patterns=r"model.layers.\1.ffn_hc.fn" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_ffn_base$", target_patterns=r"model.layers.\1.ffn_hc.base" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.hc_ffn_scale$", target_patterns=r"model.layers.\1.ffn_hc.scale" + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.attn\.", + target_patterns=r"model.layers.\1.self_attn.", + ), + WeightRenaming( + source_patterns=r"^layers\.(\d+)\.ffn\.", + target_patterns=r"model.layers.\1.mlp.", + ), + # ---- Pass 2: in-prefix specific renames (operate on already-prefixed keys) ---- + # These can safely run after the structural prefix renames because their + # source patterns include the ``model.layers.X.self_attn.`` / ``model.layers.X.mlp.`` + # prefix. On reverse the order flips so these undo first, restoring the + # specific upstream names *before* the structural rules strip the prefix. 
+ WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.attn_sink$", + target_patterns=r"model.layers.\1.self_attn.sinks", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.norm\.", + target_patterns=r"model.layers.\1.self_attn.compressor.indexer.kv_norm.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.ape$", + target_patterns=r"model.layers.\1.self_attn.compressor.indexer.position_bias", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.", + target_patterns=r"model.layers.\1.self_attn.compressor.indexer.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.", + target_patterns=r"model.layers.\1.self_attn.compressor.indexer.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.compressor\.norm\.", + target_patterns=r"model.layers.\1.self_attn.compressor.kv_norm.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.compressor\.ape$", + target_patterns=r"model.layers.\1.self_attn.compressor.position_bias", + ), + # Attention / compressor / indexer leaf weights: upstream uses paper notation + # (``wq_a`` / ``wq_b`` / ``wkv`` / ``wo_a`` / ``wo_b`` / ``wgate``); we + # rename to the standard transformers ``*_proj`` form. Compressor / Indexer + # ``wkv`` / ``wgate`` are caught by the same patterns since they sit under + # ``self_attn.`` after the Pass 1 prefix rewrite. + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wq_a\.", + target_patterns=r"model.layers.\1.self_attn.\2.q_a_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wq_b\.", + target_patterns=r"model.layers.\1.self_attn.\2.q_b_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wkv\.", + target_patterns=r"model.layers.\1.self_attn.\2.kv_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wgate\.", + target_patterns=r"model.layers.\1.self_attn.\2.gate_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wo_a\.", + target_patterns=r"model.layers.\1.self_attn.\2.o_a_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.(.*?)\.wo_b\.", + target_patterns=r"model.layers.\1.self_attn.\2.o_b_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.wq_a\.", + target_patterns=r"model.layers.\1.self_attn.q_a_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.wq_b\.", + target_patterns=r"model.layers.\1.self_attn.q_b_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.wkv\.", + target_patterns=r"model.layers.\1.self_attn.kv_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.wo_a\.", + target_patterns=r"model.layers.\1.self_attn.o_a_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.self_attn\.wo_b\.", + target_patterns=r"model.layers.\1.self_attn.o_b_proj.", + ), + # Aux-loss-free routing bias: upstream ships ``gate.bias`` (V3 convention); + # we register it as ``e_score_correction_bias`` (cross-model standard name). 
+ WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.mlp\.gate\.bias$", + target_patterns=r"model.layers.\1.mlp.gate.e_score_correction_bias", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w1\.", + target_patterns=r"model.layers.\1.mlp.shared_experts.gate_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w2\.", + target_patterns=r"model.layers.\1.mlp.shared_experts.down_proj.", + ), + WeightRenaming( + source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w3\.", + target_patterns=r"model.layers.\1.mlp.shared_experts.up_proj.", + ), + WeightConverter( + source_patterns=[ + "experts.*.w1.weight", + "experts.*.w3.weight", + ], + target_patterns="experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="experts.*.w2.weight", + target_patterns="experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], "llava": [ WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), @@ -687,8 +864,11 @@ def get_model_conversion_mapping( if add_legacy: weight_conversions.extend(get_checkpoint_conversion_mapping("legacy")) - # Add the ones from the quantizer as well if provided + # Let the quantizer rewrite / augment the conversion pipeline. This is where the + # FP8 dequantizer (when ``dequantize=True``) prepends a ``Fp8Dequantize`` op to + # every existing converter so that per-block scales are applied *before* any + # expert-merge / concat ops flatten the per-expert structure away. if hf_quantizer is not None: - weight_conversions.extend(hf_quantizer.get_weight_conversions()) + weight_conversions = hf_quantizer.update_weight_conversions(weight_conversions) return weight_conversions diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index cd0710649c91..347a80130826 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -156,8 +156,11 @@ def convert( target_pattern = self.get_target_pattern(target_patterns) all_tensors = [] # Very important to keep the relative order of the source patterns here, so we iterate over them not the - # input directly as it's unordered! + # input directly as it's unordered! Skip patterns that prior ops in the chain (e.g. ``Fp8Dequantize``) + # have already consumed and dropped from ``input_dict``. 
for source_pattern in source_patterns: + if source_pattern not in input_dict: + continue tensors = input_dict[source_pattern] if isinstance(tensors, list): all_tensors.extend(tensors) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index c64f1ce23ec2..66c4366d1726 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -16,7 +16,7 @@ from torch.nn import functional as F from ..activations import ACT2FN -from ..core_model_loading import ConversionOps, _IdentityOp +from ..core_model_loading import ConversionOps from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import @@ -809,12 +809,7 @@ class Fp8Quantize(ConversionOps): def __init__(self, hf_quantizer): self.hf_quantizer = hf_quantizer - def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]: - # Unpack single key/value (value may be wrapped in a list) - target_keys, value = tuple(input_dict.items())[0] - value = value[0] - - # Resolve block size (support dict-like or attr-like quant_config) + def _resolve_block_size(self, value: torch.Tensor) -> tuple[int, int]: block_size = None if self.hf_quantizer.quantization_config is not None: if isinstance(self.hf_quantizer.quantization_config, dict): @@ -823,98 +818,175 @@ def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor] block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None) if block_size is None: block_size = (value.shape[-2], value.shape[-1]) - - block_m, block_n = block_size + return tuple(block_size) + + def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor]: + # Pass through tensors that aren't tileable (1D norms / biases, or shapes + # that don't divide cleanly by the configured block) — they were never + # FP8-quantized on the load side, so the reverse op shouldn't touch them. + if value.ndim < 2: + return {key: value} + block_m, block_n = self._resolve_block_size(value) rows, cols = value.shape[-2], value.shape[-1] - - # Enforce exact tiling like your original if rows % block_m != 0 or cols % block_n != 0: - raise ValueError( - f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}" - ) + return {key: value} # Leading dims can be empty (2D) or include num_experts/... 
(3D+) leading_shape = value.shape[:-2] rows_tiles = rows // block_m cols_tiles = cols // block_n - original_shape = value.shape value_fp32 = value.to(torch.float32) - # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n) reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n) - - # Per-tile max-abs over the block dims - # dims: block_m is at -3, block_n is at -1 after the reshape + # Per-tile max-abs over the block dims (block_m at -3, block_n at -1) max_abs = reshaped.abs().amax(dim=(-3, -1)) safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs)) - - # Tile scale (we store inverse scale like your Linear: weight_scale_inv) + # We store inverse scale to match the upstream ``weight_scale_inv`` convention scales = _FP8_MAX / safe_max_abs scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable - - # Broadcast scales back over the block dims and quantize - # max_abs/scales shape: (..., rows_tiles, cols_tiles) - scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1) + # Broadcast scales over the block dims and quantize + scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # (..., rows_tiles, 1, cols_tiles, 1) scaled = reshaped * scales_broadcast - quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) - quantized = quantized.reshape(original_shape) + inv_scales = (1.0 / scales).to(torch.float32) + scale_key = key.rsplit(".", 1)[0] + ".weight_scale_inv" if key.endswith("weight") else key + "_scale_inv" + return {key: quantized, scale_key: inv_scales} - inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles) - if target_keys.endswith("weight"): - scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv" - else: - scale_key = target_keys + "_scale_inv" + def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]: + # Quantize every (key, tensor) entry in the dict. Single-tensor case (legacy + # callers that pass one key) and multi-tensor case (reverse of an expert + # ``MergeModulelist`` that emits one key per expert) are handled the same way. + result: dict[str, torch.Tensor] = {} + for key, value in input_dict.items(): + tensor = value[0] if isinstance(value, list) else value + result.update(self._quantize_one(key, tensor)) + return result - # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts) - return { - target_keys: quantized, - scale_key: inv_scales, - } + @property + def reverse_op(self) -> "ConversionOps": + return Fp8Dequantize(self.hf_quantizer) class Fp8Dequantize(ConversionOps): - """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor.""" + """Dequantize FP8 weights using their per-block ``weight_scale_inv``. + + Designed to run as the *first* op in any :class:`WeightConverter` chain when + loading with ``dequantize=True`` — :meth:`update_weight_conversions` on the + FP8 quantizer attaches it to each existing model-specific converter so that + per-expert (weight, scale) pairs are folded into full-precision tensors before + the chain's merge / concat ops collapse the per-expert structure. + + Pattern semantics + Input ``input_dict`` carries one entry per source pattern; each value is a + list of tensors (one per ``*`` match). 
For every weight pattern that has a + sibling ``*.weight_scale_inv`` pattern in the dict, this op pairs them up by + index, dequantizes per-pair, and emits the dequantized list under the + original *weight* key. Scale entries are dropped from the output so the + remaining ops only see weights. + """ def __init__(self, hf_quantizer): self.hf_quantizer = hf_quantizer + def _scale_pattern_for(self, weight_pattern: str) -> str: + # Strip the optional ``$`` regex anchor so we can match the underlying name. + anchored = weight_pattern.endswith("$") + base = weight_pattern[:-1] if anchored else weight_pattern + if base.endswith(".weight"): + scale = base[: -len(".weight")] + ".weight_scale_inv" + elif base == "weight": + scale = "weight_scale_inv" + else: + scale = base + "_scale_inv" + return scale + "$" if anchored else scale + + # E2M1 (FP4) value table — checkpoints sometimes ship MoE experts as packed FP4 + # (two e2m1 nibbles per int8 byte), so the "weight" dtype lands as ``int8`` / + # ``float4_e2m1fn_x2`` and we have to unpack before applying the scale grid. + _FP4_E2M1_LUT = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0) + + def _unpack_fp4(self, packed: torch.Tensor) -> torch.Tensor: + """Two ``e2m1`` FP4 values per byte → float32 tensor twice as wide on the last dim.""" + lut = torch.tensor(self._FP4_E2M1_LUT, dtype=torch.float32, device=packed.device) + u8 = packed.contiguous().view(torch.uint8) + low = (u8 & 0xF).long() + high = ((u8 >> 4) & 0xF).long() + unpacked = torch.stack([lut[low], lut[high]], dim=-1) + return unpacked.reshape(*packed.shape[:-1], 2 * packed.shape[-1]) + + def _dequantize_one(self, quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: + # FP4 path: int8 / float4_e2m1fn_x2 stores two nibbles per byte. Unpack to fp32 + # first so the rest of the routine sees a normal (rows, cols) float matrix. + fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) + if quantized.dtype == torch.int8 or (fp4_dtype is not None and quantized.dtype == fp4_dtype): + quantized_fp32 = self._unpack_fp4(quantized) + else: + quantized_fp32 = quantized.to(torch.float32) + rows, cols = quantized_fp32.shape[-2:] + # Derive block size from the scale grid rather than the global config: MoE experts + # ship MXFP4 with a ``[1, 32]`` block, dense linears ship FP8 with ``[128, 128]``, + # and the same dequant has to handle both within one checkpoint. + scale_rows, scale_cols = scales.shape[-2:] + if rows % scale_rows or cols % scale_cols: + raise ValueError( + f"Weight shape ({rows}, {cols}) not divisible by scale grid ({scale_rows}, {scale_cols})." + ) + block_m = rows // scale_rows + block_n = cols // scale_cols + # ``ue8m0`` (``float8_e8m0fnu``) scales have no CUDA ``mul`` kernel, and casting + # the FP8 weight to that dtype loses precision. Promote both sides to fp32 for + # the math; emit in the scales' dtype when it's a real float, otherwise bf16. 
+ out_dtype = scales.dtype if scales.dtype.is_floating_point and scales.element_size() >= 2 else torch.bfloat16 + original_shape = quantized_fp32.shape + q = quantized_fp32.reshape(-1, scale_rows, block_m, scale_cols, block_n) + s = scales.to(torch.float32).reshape(-1, scale_rows, scale_cols).unsqueeze(-1).unsqueeze(2) + return (q * s).to(out_dtype).reshape(original_shape) + def convert( self, - input_dict: dict[str, torch.Tensor], + input_dict: dict[str, list[torch.Tensor] | torch.Tensor], full_layer_name: str | None = None, **kwargs, - ) -> dict[str, torch.Tensor]: - if len(input_dict) < 2: - # case where we only got weights, need to check for "weight$" - return {full_layer_name: input_dict["weight$"]} - - quantized = input_dict["weight$"][0] - scales = input_dict["weight_scale_inv"][0] - - rows, cols = quantized.shape[-2:] - block_size = self.hf_quantizer.quantization_config.weight_block_size - if block_size is None: - block_size = (quantized.shape[-2], quantized.shape[-1]) - - block_m, block_n = block_size - - if rows % block_m != 0 or cols % block_n != 0: - raise ValueError( - f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})." - ) - quantized = quantized.to(scales.dtype) - reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) - expanded_scales = scales.reshape(-1, rows // block_m, cols // block_n) - expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) - dequantized = reshaped * expanded_scales - - return { - full_layer_name: dequantized.reshape(quantized.shape), - } + ) -> dict[str, list[torch.Tensor] | torch.Tensor]: + # Backward-compatible single-tensor path (the legacy fallback converter declares + # ``["weight$", "weight_scale_inv", "activation_scale"]`` and produces a single + # ``weight`` target). Also handles the no-scale case (e.g. RMSNorm weights that + # match ``weight$`` but ship no ``weight_scale_inv`` alongside). + if "weight$" in input_dict: + quantized = input_dict["weight$"] + quantized = quantized[0] if isinstance(quantized, list) else quantized + if "weight_scale_inv" in input_dict: + scales = input_dict["weight_scale_inv"] + scales = scales[0] if isinstance(scales, list) else scales + return {full_layer_name: self._dequantize_one(quantized, scales)} + return {full_layer_name: quantized} + + # Generic chain path: dequantize every weight pattern that has a sibling scale. + result: dict[str, list[torch.Tensor] | torch.Tensor] = {} + for key, value in input_dict.items(): + if "activation_scale" in key or "weight_scale_inv" in key: + continue # consumed by the dequant; drop from the chain + scale_key = self._scale_pattern_for(key) + if scale_key not in input_dict: + # No scale to apply (e.g. unrelated entry) — pass through untouched. + result[key] = value + continue + weights = value if isinstance(value, list) else [value] + scales = input_dict[scale_key] + scales = scales if isinstance(scales, list) else [scales] + if len(weights) != len(scales): + raise ValueError( + f"Fp8Dequantize: weight/scale count mismatch for {key} " + f"({len(weights)} weights vs {len(scales)} scales)." + ) + result[key] = [self._dequantize_one(w, s) for w, s in zip(weights, scales)] + return result @property def reverse_op(self) -> "ConversionOps": - return _IdentityOp() + # Round-trip: dequantize on load -> re-quantize on save, so the saved + # checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``) + # whether the in-memory state stayed quantized or was dequantized for compute. 
+ return Fp8Quantize(self.hf_quantizer) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 45e43fdaf3aa..47306a44b810 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -1414,6 +1414,10 @@ def create_chunked_causal_mask( "full_attention": create_causal_mask, "sliding_attention": create_sliding_window_causal_mask, "chunked_attention": create_chunked_causal_mask, + # V4 attention types all share the sliding-window causal mask; the long-range + # branch's compressed segment is appended to keys after the mask is built. + "compressed_sparse_attention": create_sliding_window_causal_mask, + "heavily_compressed_attention": create_sliding_window_causal_mask, } diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b041964bbdfc..31d59d2f4862 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2371,7 +2371,7 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): if getattr(module, "weight", None) is not None: - init.normal_(module.weight, mean=0.0, std=std) + init.normal_(module.weight.float(), mean=0.0, std=std) if module.bias is not None: init.zeros_(module.bias) elif isinstance(module, nn.Embedding): diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0e4d55350828..078e0eef732e 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -93,6 +93,7 @@ from .decision_transformer import * from .deepseek_v2 import * from .deepseek_v3 import * + from .deepseek_v4 import * from .deepseek_vl import * from .deepseek_vl_hybrid import * from .deformable_detr import * diff --git a/src/transformers/models/auto/auto_mappings.py b/src/transformers/models/auto/auto_mappings.py index 6fc8e7de6f3c..e6a863c1af87 100644 --- a/src/transformers/models/auto/auto_mappings.py +++ b/src/transformers/models/auto/auto_mappings.py @@ -122,6 +122,7 @@ ("decision_transformer", "DecisionTransformerConfig"), ("deepseek_v2", "DeepseekV2Config"), ("deepseek_v3", "DeepseekV3Config"), + ("deepseek_v4", "DeepseekV4Config"), ("deepseek_vl", "DeepseekVLConfig"), ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), ("deformable_detr", "DeformableDetrConfig"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 06998e9f02df..514e502d0e8d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -113,6 +113,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("decision_transformer", "DecisionTransformerModel"), ("deepseek_v2", "DeepseekV2Model"), ("deepseek_v3", "DeepseekV3Model"), + ("deepseek_v4", "DeepseekV4Model"), ("deepseek_vl", "DeepseekVLModel"), ("deepseek_vl_hybrid", "DeepseekVLHybridModel"), ("deformable_detr", "DeformableDetrModel"), @@ -637,6 +638,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("dbrx", "DbrxForCausalLM"), ("deepseek_v2", "DeepseekV2ForCausalLM"), ("deepseek_v3", "DeepseekV3ForCausalLM"), + ("deepseek_v4", "DeepseekV4ForCausalLM"), ("diffllama", "DiffLlamaForCausalLM"), ("doge", "DogeForCausalLM"), ("dots1", "Dots1ForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 6d0adc8473a6..8f3c7bcc6069 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ 
b/src/transformers/models/auto/tokenization_auto.py @@ -349,6 +349,7 @@ "chatlm", "deepseek_v2", "deepseek_v3", + "deepseek_v4", "deepseek_vl", "deepseek_vl_hybrid", "deepseek_vl_v2", diff --git a/src/transformers/models/deepseek_v4/__init__.py b/src/transformers/models/deepseek_v4/__init__.py new file mode 100644 index 000000000000..fe0228917078 --- /dev/null +++ b/src/transformers/models/deepseek_v4/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deepseek_v4 import * + from .modeling_deepseek_v4 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py new file mode 100644 index 000000000000..0cad12e4022e --- /dev/null +++ b/src/transformers/models/deepseek_v4/configuration_deepseek_v4.py @@ -0,0 +1,285 @@ +# Copyright 2026 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +DEEPSEEK_V4_LAYER_TYPES = ( + "sliding_attention", + "compressed_sparse_attention", + "heavily_compressed_attention", +) + + +_COMPRESS_RATIO_TO_LAYER_TYPE = { + 0: "sliding_attention", + 4: "compressed_sparse_attention", + 128: "heavily_compressed_attention", +} + + +DEEPSEEK_V4_MLP_LAYER_TYPES = ("hash_moe", "moe") + + +@auto_docstring(checkpoint="deepseek-ai/DeepSeek-V4-Flash-Base") +@strict +class DeepseekV4Config(PreTrainedConfig): + r""" + DeepSeek-V4's hybrid attention follows the paper (Section 2.3): every block is one + of three attention types — *Full Attention* (sliding-window only), *Compressed + Sparse Attention* (CSA, Section 2.3.1) and *Heavily Compressed Attention* (HCA, + Section 2.3.2). CSA compresses the KV cache by ``compress_rate_csa`` (m=4 in V4- + Flash/Pro) and selects ``index_topk`` blocks per query via the Lightning Indexer; + HCA applies a much heavier compression of ``compress_rate_hca`` (m'=128) and + skips sparse selection. Both branches add a small uncompressed sliding-window + branch for fine-grained locality. + + layer_types (`list[str]`): Per-layer attention schedule with values from + ``{"compressed_sparse_attention", "heavily_compressed_attention"}``. + V4-Pro default: 2× HCA bootstrap + interleaved CSA / HCA. 
+ compress_rates (`dict[str, int]`): Per-layer-type compression rate. Default + ``{"compressed_sparse_attention": 4, "heavily_compressed_attention": 128}`` + (m=4 for CSA, m'=128 for HCA, paper §2.3.1 / §2.3.2). BC: configs that ship + ``compress_rate_csa`` / ``compress_rate_hca`` as top-level kwargs are folded + in at ``__post_init__`` time. + rope_theta (`float`): RoPE base for the main self-attention rotary. + compress_rope_theta (`float`): RoPE base for the compressed branches (paired with + ``rope_scaling`` for YaRN). + partial_rotary_factor (`float`, *optional*): Fraction of head_dim that gets RoPE. + Defaults to ``qk_rope_head_dim / head_dim`` so cos/sin sizes to ``qk_rope_head_dim``. + hc_mult (`int`): Manifold-Constrained Hyper-Connection (mHC) expansion factor n_hc + (always active; Section 2.2). + hc_sinkhorn_iters (`int`): Sinkhorn-Knopp iterations t_max for the mHC residual + mapping projection onto doubly-stochastic matrices. + hc_eps (`float`): Numerical floor for the Sinkhorn-Knopp normalization. + mlp_layer_types (`list[str]`): Per-layer MoE schedule with values from + ``{"hash_moe", "moe"}``. ``hash_moe`` routes via a frozen + ``tid2eid[input_ids]`` lookup (paper §2.1, "Hash-MoE bootstrap"); ``moe`` + is the standard top-k routed MoE. Default: 3× ``hash_moe`` then ``moe`` + for the rest. BC: legacy configs that ship ``num_hash_layers`` as a + top-level kwarg are folded in at ``__post_init__`` time. + scoring_func (`str`): Router activation — ``sqrtsoftplus``, ``softmax``, or ``sigmoid``. + swiglu_limit (`float`): Clip routed experts' gate/up pre-activations. + sliding_window (`int`): Local window size n_win used in every attention block's + sliding-window branch. + o_groups (`int`): Number of head-groups g in the grouped output projection + (paper §2.3.1, "Grouped Output Projection"). + o_lora_rank (`int`): Per-group intermediate dim d_g in the grouped output projection. + index_n_heads (`int`): Number of indexer query heads n_h^I (paper §2.3.1, eq. 14). + index_head_dim (`int`): Indexer head dim c^I (paper §2.3.1). + index_topk (`int`): Number of compressed entries per query the Lightning Indexer + keeps via top-k (paper §2.3.1, eq. 17). + num_nextn_predict_layers (`int`): MTP layer count in the upstream checkpoint + (not instantiated here). + """ + + model_type = "deepseek_v4" + keys_to_ignore_at_inference = ["past_key_values"] + # ``num_local_experts`` is the standard MoE attr name (read by FP8 / TP integrations); + # ``intermediate_size`` is what :class:`LlamaMLP` reads for the shared expert width + # — V4 only ships ``moe_intermediate_size`` so we route the read through. + attribute_map = { + "num_local_experts": "n_routed_experts", + "intermediate_size": "moe_intermediate_size", + } + + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + base_model_tp_plan = { + # q_a_proj / kv_proj outputs feed RMSNorms (q_norm / kv_norm) that normalise + # across the full output dim — sharding the output would break the norm. Only + # q_b_proj is colwise-sharded (per-head split is safe: q_head_norm is per-head), + # and o_b_proj is rowwise (input-dim sharded). o_a_proj is a GroupedLinear + # whose forward uses ``torch.bmm``; the standard TP wrappers don't handle bmm. 
+ "layers.*.self_attn.q_b_proj": "colwise", + "layers.*.self_attn.o_b_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + } + + vocab_size: int = 129280 + hidden_size: int = 4096 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 43 + num_attention_heads: int = 64 + num_key_value_heads: int = 1 + head_dim: int = 512 + q_lora_rank: int = 1024 + default_partial_rotary_factor = 64 / 512 # ``qk_rope_head_dim`` (64) / ``head_dim`` (512) + num_experts_per_tok: int = 6 + n_routed_experts: int = 256 + n_shared_experts: int = 1 + scoring_func: str = "sqrtsoftplus" + norm_topk_prob: bool = True + routed_scaling_factor: float = 1.5 + max_position_embeddings: int = 1048576 + rope_theta: float | int = 10000.0 + + layer_types: list[str] | None = None + compress_rates: dict | None = None + default_compress_rates = {"compressed_sparse_attention": 4, "heavily_compressed_attention": 128} + compress_rope_theta: float | int = 160000.0 + hc_mult: int = 4 + hc_sinkhorn_iters: int = 20 + hc_eps: float = 1.0e-6 + mlp_layer_types: list[str] | None = None + default_num_hash_layers = 3 + swiglu_limit: float = 10.0 + sliding_window: int = 128 + o_groups: int = 8 + o_lora_rank: int = 1024 + index_n_heads: int = 64 + index_head_dim: int = 128 + index_topk: int = 512 + num_nextn_predict_layers: int = 1 + + output_router_logits: bool = False + router_aux_loss_coef: float = 0.001 + router_jitter_noise: float = 0.0 + + hidden_act: str = "silu" + initializer_range: float = 0.02 + rms_norm_eps: float = 1.0e-6 + use_cache: bool = True + pad_token_id: int | None = None + bos_token_id: int | None = 0 + eos_token_id: int | list[int] | None = 1 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + partial_rotary_factor: float | None = None + attention_bias: bool = False + mlp_bias: bool = False + attention_dropout: float = 0.0 + + def validate_layer_type(self): + """V4 narrows the global ``ALLOWED_LAYER_TYPES`` to the three attention-block + types and two MLP-block types it actually ships with, on top of the standard + length / type-membership checks. + """ + if self.num_hidden_layers is None: + return + for name, types, allowed in ( + ("layer_types", self.layer_types, DEEPSEEK_V4_LAYER_TYPES), + ("mlp_layer_types", self.mlp_layer_types, DEEPSEEK_V4_MLP_LAYER_TYPES), + ): + if types is None: + continue + if len(types) != self.num_hidden_layers: + raise ValueError( + f"`num_hidden_layers` ({self.num_hidden_layers}) must equal `len({name})` ({len(types)})." + ) + bad = [t for t in types if t not in allowed] + if bad: + raise ValueError(f"`{name}` entries must be one of {allowed} for DeepSeek-V4; got {bad}.") + + def _apply_legacy_kwargs(self, kwargs: dict) -> dict: + """Strip and stash legacy V4 kwargs that older checkpoints / configs ship under + their original V3-flavoured names. The values are kept on the instance so + ``__post_init__`` can fold them into the new fields after the parent's init has + run, and the kwargs dict is returned cleaned for ``PreTrainedConfig.__post_init__``. 
+ """ + self._legacy_compress_ratios = kwargs.pop("compress_ratios", None) + self._legacy_compress_rate_csa = kwargs.pop("compress_rate_csa", None) + self._legacy_compress_rate_hca = kwargs.pop("compress_rate_hca", None) + self._legacy_num_hash_layers = kwargs.pop("num_hash_layers", None) + # ``qk_rope_head_dim`` isn't a config field — it's derived from + # ``partial_rotary_factor * head_dim`` and only set as a runtime attribute. + self._legacy_qk_rope_head_dim = kwargs.pop("qk_rope_head_dim", None) + return kwargs + + def _resolve_compress_rates(self) -> None: + if self.compress_rates is None: + self.compress_rates = dict(self.default_compress_rates) + if self._legacy_compress_rate_csa is not None: + self.compress_rates["compressed_sparse_attention"] = self._legacy_compress_rate_csa + if self._legacy_compress_rate_hca is not None: + self.compress_rates["heavily_compressed_attention"] = self._legacy_compress_rate_hca + + def _resolve_layer_types(self) -> None: + n = self.num_hidden_layers + if self.layer_types is None and self._legacy_compress_ratios is not None: + # Translate the V4 checkpoint's per-layer integer ``compress_ratios`` into the + # named ``layer_types`` schedule (0 = sliding-only, 4 = CSA, 128 = HCA). + self.layer_types = [_COMPRESS_RATIO_TO_LAYER_TYPE[r] for r in self._legacy_compress_ratios] + if self.layer_types is None: + # V4-Pro default: two HCA bootstrap layers, then CSA / HCA interleaved. + interleave = [ + "compressed_sparse_attention" if i % 2 else "heavily_compressed_attention" + for i in range(max(n - 2, 0)) + ] + head = ["heavily_compressed_attention"] * min(n, 2) + self.layer_types = head + interleave + self.layer_types = list(self.layer_types[:n]) + + def _resolve_mlp_layer_types(self) -> None: + n = self.num_hidden_layers + if self.mlp_layer_types is None: + n_hash = ( + self._legacy_num_hash_layers + if self._legacy_num_hash_layers is not None + else self.default_num_hash_layers + ) + self.mlp_layer_types = ["hash_moe"] * min(n, n_hash) + ["moe"] * max(0, n - n_hash) + self.mlp_layer_types = list(self.mlp_layer_types[:n]) + + def _resolve_partial_rotary_factor(self) -> None: + if self.partial_rotary_factor is None: + self.partial_rotary_factor = ( + self._legacy_qk_rope_head_dim / self.head_dim + if self._legacy_qk_rope_head_dim is not None + else self.default_partial_rotary_factor + ) + # Runtime-only attribute; never declared as a dataclass field. + self.qk_rope_head_dim = int(self.head_dim * self.partial_rotary_factor) + + def _resolve_rope_parameters(self) -> None: + """Normalize ``rope_parameters`` into a per-rope-type dict + ``{"main": {...}, "compress": {...}}`` (Gemma3 pattern; keys are *rope-type* + labels, unrelated to ``layer_types``). Idempotent across save/load. + + By the time we get here :class:`PreTrainedConfig` has already run + :meth:`RotaryEmbeddingConfigMixin.convert_rope_params_to_dict`, which folds the + checkpoint's legacy top-level ``rope_scaling`` block into ``self.rope_parameters`` + as a flat dict (``rope_type``, ``factor``, YaRN params, …). We just split that + flat dict into the two rope-type buckets — the only difference between the two + is the ``rope_theta`` base (main attention uses ``rope_theta=10000``; the + compressor / indexer uses ``compress_rope_theta=160000``). 
+ """ + rp = self.rope_parameters or {} + if isinstance(rp.get("main"), dict) and isinstance(rp.get("compress"), dict): + self.rope_parameters = {"main": rp["main"], "compress": rp["compress"]} + return + base = {k: v for k, v in rp.items() if k not in ("main", "compress")} + base.setdefault("rope_theta", self.rope_theta) + base["partial_rotary_factor"] = self.partial_rotary_factor + base.setdefault("rope_type", "default") + main = dict(base) + compress = {**base, "rope_theta": self.compress_rope_theta} + self.rope_parameters = {"main": main, "compress": compress} + + def __post_init__(self, **kwargs): + kwargs = self._apply_legacy_kwargs(kwargs) + PreTrainedConfig.__post_init__(self, **kwargs) + self._resolve_compress_rates() + self._resolve_layer_types() + self._resolve_mlp_layer_types() + self._resolve_partial_rotary_factor() + self._resolve_rope_parameters() + + +__all__ = ["DeepseekV4Config"] diff --git a/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py new file mode 100644 index 000000000000..50827c0a92e1 --- /dev/null +++ b/src/transformers/models/deepseek_v4/modeling_deepseek_v4.py @@ -0,0 +1,1610 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_v4/modular_deepseek_v4.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_v4.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowLayer +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub +from ...masking_utils import create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from .configuration_deepseek_v4 import DeepseekV4Config + + +@use_kernel_forward_from_hub("RMSNorm") +class DeepseekV4RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + DeepseekV4RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV4UnweightedRMSNorm(nn.Module): + """RMSNorm without a learned weight — applied per-head to Q after ``q_b_proj`` + in :class:`DeepseekV4Attention`. Matches the V4-Flash reference's ``inference/ + model.py:498`` rescale ``q *= rsqrt(mean(q**2) + eps)``; without it attention + scores end up at the wrong scale and the model collapses to a single repeated + token within a handful of layers. + """ + + def __init__(self, eps: float = 1.0e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.float().square().mean(-1, keepdim=True) + self.eps).to(x.dtype) + + +class DeepseekV4RotaryEmbedding(nn.Module): + """Multi-layer-type rotary embedding (Laguna pattern: partial rotary on top of + Gemma3's per-layer-type buffers), specialised for V4's *interleaved* RoPE. Holds + two ``inv_freq`` buffers — ``"main"`` for self-attention (``rope_theta``) and + ``"compress"`` for the Compressor / Indexer (``compress_rope_theta``); both + honour ``partial_rotary_factor`` so cos/sin sizes to ``qk_rope_head_dim``. + ``forward(x, position_ids, layer_type=...)`` returns the half-sized cos/sin + directly — interleaved RoPE rotates pairs ``(x[2i], x[2i+1])`` so we want one + ``θ_i`` per pair, *not* the end-to-end duplicated table half-split RoPE needs. + + The ``layer_types`` here are the *rope* layer types (``"main"`` / ``"compress"``), + keys of ``config.rope_parameters``. They are unrelated to ``config.layer_types``, + which lists the per-decoder-block attention type. 
+ """ + + inv_freq: torch.Tensor # fix linting for `register_buffer` + + # Class-level rather than ``list(set(config.layer_types))`` (gemma3's pattern): + # V4's rope keys ``"main"`` / ``"compress"`` are *orthogonal* to the per-block + # attention types in ``config.layer_types`` — every attention block uses ``main``, + # every compressor / indexer uses ``compress``, regardless of which of the three + # block types the layer is. + layer_types = ("main", "compress") + + def __init__(self, config: "DeepseekV4Config", device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + self.rope_type = {} + for layer_type in set(self.layer_types): + params = config.rope_parameters.get(layer_type) + if params is None: + continue + self.rope_type[layer_type] = params.get("rope_type", "default") + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + inv_freq, scaling = rope_init_fn(config, device, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", scaling) + + @staticmethod + def compute_default_rope_parameters( + config: DeepseekV4Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + layer_type (`str`, *optional*): + The current layer type if the model has different RoPE parameters per type. + Should not be used unless `config.layer_types is not None` + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters[layer_type]["rope_theta"] + # key difference to gemma3: partial rope + partial_rotary_factor = config.rope_parameters[layer_type].get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids, layer_type=None): + # Interleaved RoPE: one ``θ_i`` per pair (``rope_head_dim // 2`` entries), + # no end-to-end duplication. Same shape as ``inv_freq @ position_ids``. 
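+        # For example (hypothetical size), with ``qk_rope_head_dim = 64`` the returned cos/sin
+        # carry 32 angles per position; ``apply_rotary_pos_emb`` later ``repeat_interleave``s them
+        # back to 64 channels so each pair ``(x[2i], x[2i+1])`` shares the single angle ``θ_i``.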
+ inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + cos = freqs.cos() * attention_scaling + sin = freqs.sin() * attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def _sliding_kv_update( + cache_layer: "DynamicSlidingWindowLayer", key_states: torch.Tensor, value_states: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Shared sliding-window K=V update body. V4 uses shared-KV MQA, so ``keys`` and + ``values`` point to the same storage on every layer; both V4 cache layer types + (HCA / CSA) call this from their ``update``.""" + if not cache_layer.is_initialized: + cache_layer.lazy_initialization(key_states, value_states) + cache_layer.values = cache_layer.keys + cache_layer.cumulative_length += key_states.shape[-2] + full = torch.cat([cache_layer.keys, key_states], dim=-2) + cache_layer.keys = full[:, :, -cache_layer.sliding_window + 1 :, :] + cache_layer.values = cache_layer.keys + return full, full + + +def _update_window_buffer( + buffer_kv: torch.Tensor | None, + buffer_gate: torch.Tensor | None, + kv: torch.Tensor, + gate: torch.Tensor, + compress_rate: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Merge a still-buffered tail with freshly projected ``(kv, gate)`` and split off + the longest window-aligned chunk. Used by both the compressor- and indexer-side + window buffers; tokens past the last full window stay in the buffer until the + next call rounds them out to a multiple of ``compress_rate``.""" + if buffer_kv is not None and buffer_kv.shape[1]: + kv = torch.cat([buffer_kv, kv], dim=1) + gate = torch.cat([buffer_gate, gate], dim=1) + usable = (kv.shape[1] // compress_rate) * compress_rate + return kv[:, :usable], gate[:, :usable], kv[:, usable:], gate[:, usable:] + + +def _append_to_pool(pool: torch.Tensor | None, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted compressed entries to a running pool, returning the + full pool (or an empty tensor if nothing has been pooled yet).""" + if new_pooled.shape[1] > 0: + return new_pooled if pool is None else torch.cat([pool, new_pooled], dim=1) + if pool is None: + return new_pooled.new_zeros((new_pooled.shape[0], 0, new_pooled.shape[-1])) + return pool + + +class DeepseekV4HCACache(DynamicSlidingWindowLayer): + """Cache layer for HCA blocks (paper §2.3.2). Holds the long-range compressor's + buffer / pool / count on top of the sliding-window K=V branch. HCA uses + *non-overlapping* windows, so there is **no** overlap state, and HCA has **no** + indexer either. + + Fields on top of :class:`DynamicSlidingWindowLayer`: + + * ``compressor_pool`` — the running list of compressed KV entries emitted so + far (one per ``compress_rate_hca`` source tokens; the long-range KVs the + attention concatenates onto its sliding-window keys / values). 
+ * ``compressor_buffer_kv`` / ``compressor_buffer_gate`` — source tokens that + arrived between two full windows; once the buffer hits ``compress_rate_hca`` + tokens the compressor closes a window, emits one pooled entry, and drains + the buffer. + * ``compressor_pool_count`` — number of compressed entries emitted so far, + so ``compressor_pool_count * compress_rate_hca`` is the absolute position + of the *next* window's first source token. + + The class-level ``layer_type`` auto-registers this class with + :data:`LAYER_TYPE_CACHE_MAPPING` so :class:`DynamicCache` builds it on its own + when ``config.layer_types[i] == "heavily_compressed_attention"``. + """ + + layer_type = "heavily_compressed_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["heavily_compressed_attention"] + self.compressor_buffer_kv: torch.Tensor | None = None + self.compressor_buffer_gate: torch.Tensor | None = None + self.compressor_pool: torch.Tensor | None = None + self.compressor_pool_count = 0 + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs): + return _sliding_kv_update(self, key_states, value_states) + + def update_compressor(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Merge the freshly projected ``(kv, gate)`` (paper §2.3.2 eqs. 20–21: + ``C = H·W^{KV}``, ``Z = H·W^Z``) with the buffered tail from prior calls and + return the longest window-aligned chunk that's ready to pool, plus the + absolute source-token position of that chunk's first window. The returned + chunk is softmax-pooled by the compressor with ``position_bias`` to emit one + compressed entry per window of ``compress_rate_hca`` tokens (eqs. 22–23).""" + first_pool_position = self.compressor_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.compressor_buffer_kv, self.compressor_buffer_gate = _update_window_buffer( + self.compressor_buffer_kv, self.compressor_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_compressor_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted compressed entries to ``compressor_pool`` + (``C^{Comp}``, paper §2.3.2 eq. 23) and return the full pool. Bumps + ``compressor_pool_count`` so the next ``update_compressor`` call knows the + absolute source-token position of its first window.""" + self.compressor_pool = _append_to_pool(self.compressor_pool, new_pooled) + self.compressor_pool_count += new_pooled.shape[1] + return self.compressor_pool + + +class DeepseekV4CSACache(DynamicSlidingWindowLayer): + """Cache layer for CSA blocks (paper §2.3.1). Holds two parallel sets of + buffer / pool / count / overlap state on top of the sliding-window K=V branch: + + * **compressor side** — the main-branch ``head_dim`` pool (the long-range KVs + the attention concatenates after top-k indexer selection). + * **indexer side** — the Lightning Indexer's smaller ``index_head_dim`` pool + (the keys ``K^{IComp}`` that queries score against to pick the top-k blocks, + eqs. 14–17). Kept separate from the compressor pool because the head dim + differs. 
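+
+    For example (hypothetical sizes), with ``head_dim = 512`` and ``index_head_dim = 64`` the
+    compressor pool holds ``[B, T, 512]`` entries while the indexer pool holds ``[B, T, 64]``
+    entries for the same ``T`` windows, which is why the two sides keep separate buffers.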
+ + Both sides use **overlapping** windows of stride ``compress_rate_csa`` and width + ``2 * compress_rate_csa`` (paper §2.3.1), so each side also keeps an + ``*_overlap_kv`` / ``*_overlap_gate`` pair holding the last full window's + projected ``(kv, gate)`` so the next forward call's first window can stitch in + its low-channel slice as the prior contribution. + + The class-level ``layer_type`` auto-registers this class with + :data:`LAYER_TYPE_CACHE_MAPPING` so :class:`DynamicCache` builds it on its own + when ``config.layer_types[i] == "compressed_sparse_attention"``. + """ + + layer_type = "compressed_sparse_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + # Compressor side + self.compressor_buffer_kv: torch.Tensor | None = None + self.compressor_buffer_gate: torch.Tensor | None = None + self.compressor_pool: torch.Tensor | None = None + self.compressor_pool_count = 0 + self.compressor_overlap_kv: torch.Tensor | None = None + self.compressor_overlap_gate: torch.Tensor | None = None + # Indexer side (parallel state at ``index_head_dim``) + self.indexer_buffer_kv: torch.Tensor | None = None + self.indexer_buffer_gate: torch.Tensor | None = None + self.indexer_pool: torch.Tensor | None = None + self.indexer_pool_count = 0 + self.indexer_overlap_kv: torch.Tensor | None = None + self.indexer_overlap_gate: torch.Tensor | None = None + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs): + return _sliding_kv_update(self, key_states, value_states) + + def update_compressor(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Compressor-side window buffer (paper §2.3.1 main-branch pool, eqs. 9–12). + Same window-aligned tail-buffering as HCA, but at the CSA cadence + (``compress_rate_csa``).""" + first_pool_position = self.compressor_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.compressor_buffer_kv, self.compressor_buffer_gate = _update_window_buffer( + self.compressor_buffer_kv, self.compressor_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_compressor_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted entries to the CSA compressor pool (the + ``C^{Comp}`` running list at ``head_dim``, eqs. 11–12).""" + self.compressor_pool = _append_to_pool(self.compressor_pool, new_pooled) + self.compressor_pool_count += new_pooled.shape[1] + return self.compressor_pool + + def get_compressor_overlap(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + return self.compressor_overlap_kv, self.compressor_overlap_gate + + def set_compressor_overlap(self, kv: torch.Tensor, gate: torch.Tensor) -> None: + self.compressor_overlap_kv = kv + self.compressor_overlap_gate = gate + + def update_indexer(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Indexer-side mirror of :meth:`update_compressor` (paper §2.3.1, "Lightning + Indexer for Sparse Selection"). Same logic at the smaller ``index_head_dim`` + — the small-head pool keys ``K^{IComp}`` (eq. 14's ``W^{IUQ}`` complement on + the key side) that the indexer scores queries against to pick the top-k + blocks (eqs. 15–17). 
Buffer / pool / count are kept separate from the + compressor's state because the head dim differs.""" + first_pool_position = self.indexer_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.indexer_buffer_kv, self.indexer_buffer_gate = _update_window_buffer( + self.indexer_buffer_kv, self.indexer_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_indexer_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted entries to the indexer pool ``K^{IComp}`` (paper + §2.3.1 eq. 16: the keys against which the ``q^I_t`` queries score for top-k + selection). Same cadence as the compressor pool — one entry per + ``compress_rate_csa`` source tokens — but at ``index_head_dim``.""" + self.indexer_pool = _append_to_pool(self.indexer_pool, new_pooled) + self.indexer_pool_count += new_pooled.shape[1] + return self.indexer_pool + + def get_indexer_overlap(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + return self.indexer_overlap_kv, self.indexer_overlap_gate + + def set_indexer_overlap(self, kv: torch.Tensor, gate: torch.Tensor) -> None: + self.indexer_overlap_kv = kv + self.indexer_overlap_gate = gate + + +class DeepseekV4GroupedLinear(nn.Linear): + """Block-diagonal grouped linear used by the V4 grouped output projection + (paper §2.3.1, "Grouped Output Projection"; HCA reuses the same scheme, + §2.3.2). With ``num_attention_heads = n_h`` and per-head dim ``c``, the core + attention's stacked output is ``c·n_h``-dim, which is *very* large for V4 + (V4-Flash: c=512, n_h=64 → 32768; V4-Pro: c=512, n_h=128 → 65536). A direct + ``c·n_h → hidden_size`` projection would dominate the per-token cost. + + The paper sidesteps that by splitting the n_h heads into ``g`` groups, projecting + each ``c·n_h/g``-dim group independently to a ``d_g``-dim intermediate output + (with ``d_g < c·n_h/g``), and then mixing the resulting ``g·d_g`` vector to + ``hidden_size`` through a single follow-up linear (``self_attn.o_b_proj``). This + module owns the per-group block (``self_attn.o_a_proj``). + + The ``weight`` parameter is shaped like a standard ``nn.Linear`` + (``[out_features, in_features_per_group]``) so quantizers keyed on + ``nn.Linear.weight`` still pick it up; ``forward`` does the per-group ``bmm``. + """ + + def __init__(self, in_features_per_group: int, out_features: int, n_groups: int, bias: bool = False): + super().__init__(in_features_per_group, out_features, bias=bias) + self.n_groups = n_groups + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., n_groups, in_features_per_group] + input_shape = x.shape[:-2] + d_in = x.shape[-1] + w = self.weight.view(self.n_groups, -1, d_in).transpose(1, 2) + x = x.reshape(-1, self.n_groups, d_in).transpose(0, 1) + y = torch.bmm(x, w).transpose(0, 1) + return y.reshape(*input_shape, self.n_groups, -1) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., 0::2] + x2 = x[..., 1::2] + return torch.stack((-x2, x1), dim=-1).flatten(-2) + + +def apply_rotary_pos_emb( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + """V4-Flash interleaved RoPE (matches the reference's ``apply_rotary_emb`` at + ``inference/model.py:232``). Rotates the **trailing** ``2 * cos.shape[-1]`` channels + of ``x`` (V4-Flash lays each head out as ``[nope | rope]``, matching the reference's + ``x[..., -rd:]`` indexing) and leaves the leading nope channels unchanged. 
The + half-sized cos / sin from :class:`DeepseekV4RotaryEmbedding` are expanded to the + pair-aligned full rope dim via ``repeat_interleave`` here, and the rotation is + done in fp32 (matching the reference's precision).""" + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + rope_dim = cos.shape[-1] + nope, rope = x[..., :-rope_dim], x[..., -rope_dim:] + rotated = ((rope.float() * cos) + (rotate_half(rope).float() * sin)).to(x.dtype) + return torch.cat([nope, rotated], dim=-1) + + +def _overlap_pool( + chunk_kv: torch.Tensor, + chunk_gate: torch.Tensor, + prior_kv: torch.Tensor | None, + prior_gate: torch.Tensor | None, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Expand ``[B, n_win, ratio, 2*head_dim]`` chunks into ``[B, n_win, 2*ratio, head_dim]`` + by stitching each window's *low-channel* slice onto the *high-channel* slice of the + prior window — matching the V4-Flash reference (``Compressor.overlap_transform``). + + Each pooled output thus mixes ``ratio`` *current* source tokens (high half of the + learned 2d split) with ``ratio`` *previous* source tokens (low half), so windows + have width ``2*ratio`` but stride ``ratio`` (paper §2.3.1). For window 0, the prior + half is filled with zero (kv) / ``-inf`` (gate, so its softmax weight is exactly 0), + unless ``prior_kv`` / ``prior_gate`` carry the last full window from a previous + forward call — in which case its low-channel slice slots into row ``[0, :ratio]``. + """ + batch, n_windows, ratio, _ = chunk_kv.shape + new_kv = chunk_kv.new_zeros((batch, n_windows, 2 * ratio, head_dim)) + new_gate = chunk_gate.new_full((batch, n_windows, 2 * ratio, head_dim), float("-inf")) + new_kv[:, :, ratio:] = chunk_kv[..., head_dim:] + new_gate[:, :, ratio:] = chunk_gate[..., head_dim:] + if n_windows > 1: + new_kv[:, 1:, :ratio] = chunk_kv[:, :-1, :, :head_dim] + new_gate[:, 1:, :ratio] = chunk_gate[:, :-1, :, :head_dim] + if prior_kv is not None and prior_gate is not None: + new_kv[:, 0, :ratio] = prior_kv[..., :head_dim].to(new_kv.dtype) + new_gate[:, 0, :ratio] = prior_gate[..., :head_dim].to(new_gate.dtype) + return new_kv, new_gate + + +def _rope_pool(pooled: torch.Tensor, rotary_emb: nn.Module, positions: torch.Tensor, layer_type: str) -> torch.Tensor: + """Apply RoPE to the trailing rope slice of each pooled entry at its deterministic + absolute position. Used by both the indexer pool and the HCA / CSA compressor pools.""" + cos, sin = rotary_emb(pooled, position_ids=positions, layer_type=layer_type) + return apply_rotary_pos_emb(pooled.unsqueeze(1), cos, sin).squeeze(1) + + +class DeepseekV4Indexer(nn.Module): + """Lightning Indexer (paper §2.3.1, eqs. 13–17). Used by Compressed Sparse + Attention (CSA) to pick the top-k compressed KV blocks per query. The indexer + runs its own scaled-down compressor at ``index_head_dim`` over the same windows + as the outer CSA compressor, then scores queries against the pooled keys with + ``∑_h w_{t,h} · ReLU(q_{t,h} · K^IComp_s)`` and keeps the top ``index_topk`` + indices. + + Class-attribute ``rope_layer_type`` selects which inv_freq buffer of the shared + :class:`DeepseekV4RotaryEmbedding` to use; the indexer always reads + ``"compress"`` (paired with ``compress_rope_theta``). 
+
+    The indexer has its own rotary because it applies RoPE to two sets of tensors:
+
+    * **pool keys** at deterministic positions ``i * compress_rate + first_pool_position``,
+    * **queries** at the model's current ``position_ids`` (variable per forward).
+
+    Both must use the same theta as the outer compressor (``compress_rope_theta``) so
+    query/key inner products are translation-invariant in the standard rope sense — if
+    they used different thetas the score ``q · k`` would carry a residual
+    position-dependent skew. We can't precompute cos/sin once at init because the query
+    positions vary per call, so the indexer owns a rotary embedding and calls it with
+    ``layer_type=self.rope_layer_type`` twice per forward (once for pool keys, once for queries).
+    """
+
+    rope_layer_type = "compress"
+
+    def __init__(self, config: DeepseekV4Config):
+        super().__init__()
+        self.compress_rate = config.compress_rates["compressed_sparse_attention"]
+        self.n_heads = config.index_n_heads
+        self.head_dim = config.index_head_dim
+        self.rope_head_dim = config.qk_rope_head_dim
+        self.index_topk = config.index_topk
+        self.softmax_scale = self.head_dim**-0.5
+        self.weights_scaling = self.n_heads**-0.5
+        # The indexer always pools with the CSA cadence (``compress_rate=4``), so its
+        # inner pool runs the same overlapping-window scheme as :class:`DeepseekV4CSACompressor`
+        # (paper §2.3.1) — ``coff = 2`` everywhere on the pool branch.
+        self.coff = 2
+        self.kv_proj = nn.Linear(config.hidden_size, self.coff * self.head_dim, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, self.coff * self.head_dim, bias=False)
+        self.position_bias = nn.Parameter(torch.empty(self.compress_rate, self.coff * self.head_dim))
+        self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, self.n_heads * self.head_dim, bias=False)
+        self.weights_proj = nn.Linear(config.hidden_size, self.n_heads, bias=False)
+        self.rotary_emb = DeepseekV4RotaryEmbedding(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        q_residual: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values: Cache | None,
+        layer_idx: int,
+    ) -> torch.LongTensor:
+        batch, seq_len, _ = hidden_states.shape
+        cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None
+
+        # --- Pool side: same overlapping windows as the outer CSA compressor, at index_head_dim ---
+        kv = self.kv_proj(hidden_states)
+        gate = self.gate_proj(hidden_states)
+        if cache_layer is None:
+            usable = (kv.shape[1] // self.compress_rate) * self.compress_rate
+            chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0
+            prior_kv, prior_gate = None, None
+        else:
+            chunk_kv, chunk_gate, first_pool_position = cache_layer.update_indexer(kv, gate)
+            prior_kv, prior_gate = cache_layer.get_indexer_overlap()
+        if chunk_kv.shape[1] > 0:
+            n_windows = chunk_kv.shape[1] // self.compress_rate
+            chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1)
+            chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1)
+            # Bias the gate logits with the learned per-slot position bias before the window is
+            # persisted and softmax-pooled.
+            chunk_gate = chunk_gate + self.position_bias.to(chunk_gate.dtype)
+            if cache_layer is not None:
+                cache_layer.set_indexer_overlap(chunk_kv[:, -1].clone(), chunk_gate[:, -1].clone())
+            chunk_kv, chunk_gate = _overlap_pool(chunk_kv, chunk_gate, prior_kv, prior_gate, self.head_dim)
+            # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that
+            # only differ by a small amount, especially with large window widths).
+ new_pooled = self.kv_norm( + (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2) + ) + positions = ( + (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position) + .unsqueeze(0) + .expand(batch, -1) + ) + new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type) + else: + new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim)) + pooled_kv = new_pooled if cache_layer is None else cache_layer.update_indexer_pool(new_pooled) + + # --- Query side --- + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids=position_ids, layer_type=self.rope_layer_type) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb(q, cos_q, sin_q).transpose(1, 2) + + # --- Score: ReLU(q·kᵀ) * weights, then top-k --- + scores = torch.matmul(q.float(), pooled_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] + scores = F.relu(scores) * self.softmax_scale + weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + topk = min(self.index_topk, pooled_kv.shape[1]) + return index_scores.topk(topk, dim=-1).indices + + +class DeepseekV4HCACompressor(nn.Module): + """Heavily Compressed Attention compressor (paper §2.3.2, eqs. 20–23). Pools + every ``compress_rate_hca`` (m'=128) source tokens into a single compressed KV + entry with **non-overlapping** windows — no overlap state, no indexer. + + The three building blocks (paper notation in parentheses): + + * **kv** = ``kv_proj(hidden_states)`` — head-dim KV projection ``C ∈ R^{n×c}`` + (eq. 20). Doubles as both key and value (shared-KV MQA). + * **gate** = ``gate_proj(hidden_states)`` — head-dim compression weights + ``Z ∈ R^{n×c}`` (eq. 21). Combined with ``position_bias`` and softmaxed per + window to produce the convex combination that mixes ``compress_rate_hca`` + source KVs into one pooled entry. + * **pool** — running list of compressed KV entries (``C^{Comp}``, eq. 23). + Lives on :class:`DeepseekV4HCACache`; the in-flight buffer of tokens that + haven't yet filled a window lives there too. + + Each closed window of m' tokens produces one pooled entry: + ``C^{Comp}_i = Σ_{j∈window} softmax(Z_j + B)_j ⊙ C_j``. RoPE on the trailing + ``rope_head_dim`` slice is applied at the deterministic absolute position + ``i * compress_rate_hca + first_pool_position`` so cross-call concatenation stays + causality-correct. Returns the running pool ``[B, 1, T, head_dim]``. + + When ``past_key_values is None`` (a checkpoint replay zeroes the cache to break + the grad-cache loop), runs in stateless single-shot mode: pool every complete + window from ``hidden_states`` and discard the remainder. 
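+
+    Example (hypothetical token counts, with the default ``compress_rate_hca = 128``): a
+    300-token prefill closes ``300 // 128 = 2`` windows and emits two pooled entries, leaving
+    44 tokens buffered; a follow-up call with 100 more tokens closes one further window
+    (144 pending tokens) and keeps the remaining 16 in the buffer.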
+    """
+
+    rope_layer_type = "compress"
+
+    def __init__(self, config: DeepseekV4Config):
+        super().__init__()
+        self.compress_rate = config.compress_rates["heavily_compressed_attention"]
+        self.head_dim = config.head_dim
+        self.rope_head_dim = config.qk_rope_head_dim
+        self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
+        self.position_bias = nn.Parameter(torch.empty(self.compress_rate, self.head_dim))
+        self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.rotary_emb = DeepseekV4RotaryEmbedding(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        q_residual: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values: Cache | None,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        # ``q_residual`` / ``position_ids`` are unused — the uniform forward signature
+        # lets :class:`DeepseekV4Attention` call either compressor without branching.
+        batch, _, _ = hidden_states.shape
+        cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None
+        kv = self.kv_proj(hidden_states)
+        gate = self.gate_proj(hidden_states)
+        if cache_layer is None:
+            usable = (kv.shape[1] // self.compress_rate) * self.compress_rate
+            chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0
+        else:
+            chunk_kv, chunk_gate, first_pool_position = cache_layer.update_compressor(kv, gate)
+        if chunk_kv.shape[1] > 0:
+            n_windows = chunk_kv.shape[1] // self.compress_rate
+            chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1)
+            chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1)
+            # Bias the gate logits with the learned per-slot position bias (the ``B`` in
+            # ``softmax(Z_j + B)``) before the softmax pooling.
+            chunk_gate = chunk_gate + self.position_bias.to(chunk_gate.dtype)
+            # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that
+            # only differ by a small amount, especially with large window widths).
+            new_pooled = self.kv_norm(
+                (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2)
+            )
+            positions = (
+                (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position)
+                .unsqueeze(0)
+                .expand(batch, -1)
+            )
+            new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type)
+        else:
+            new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim))
+        if cache_layer is None:
+            return new_pooled.unsqueeze(1)
+        return cache_layer.update_compressor_pool(new_pooled).unsqueeze(1)
+
+
+class DeepseekV4CSACompressor(nn.Module):
+    """Compressed Sparse Attention compressor (paper §2.3.1, eqs. 9–17). Pools every
+    ``compress_rate_csa`` (m=4) source tokens with **overlapping** windows — stride
+    ``compress_rate_csa`` and effective width ``2 * compress_rate_csa`` — and runs a
+    Lightning Indexer on top of the pool that scores queries with
+    ``∑_h w_{t,h} · ReLU(q_{t,h} · K^{IComp}_s)`` to gather the top ``index_topk``
+    entries per query before they reach core attention.
+
+    Compared to :class:`DeepseekV4HCACompressor`, the differences are:
+
+    * ``kv_proj`` / ``gate_proj`` / ``position_bias`` project to **2 × head_dim** (the
+      learned channel split — high half pools into the current window, low half
+      pools into the next window's overlap with this one, see :func:`_overlap_pool`).
+    * The cache layer's ``compressor_overlap_*`` state carries the last full
+      window across forward calls.
+    * A :class:`DeepseekV4Indexer` sub-module gathers the top-``index_topk`` pool
+      entries per query (paper §2.3.1, "Lightning Indexer for Sparse Selection").
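+
+    For example, with the default ``compress_rate_csa = 4``, the pooled entry for window ``w``
+    (after the first) mixes source tokens ``[4w, 4w + 4)`` through their high-channel half with
+    tokens ``[4w - 4, 4w)`` from the previous window through their low-channel half: width 8,
+    stride 4.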
+    """
+
+    rope_layer_type = "compress"
+
+    def __init__(self, config: DeepseekV4Config):
+        super().__init__()
+        self.compress_rate = config.compress_rates["compressed_sparse_attention"]
+        self.head_dim = config.head_dim
+        self.rope_head_dim = config.qk_rope_head_dim
+        # ``2 * head_dim`` because windows overlap: each pooled entry is a softmax-gated
+        # convex combination of ``compress_rate_csa`` *current* tokens (high-channel half)
+        # mixed with ``compress_rate_csa`` *previous* tokens (low-channel half). The
+        # learned channel split happens in :func:`_overlap_pool`.
+        self.kv_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False)
+        self.position_bias = nn.Parameter(torch.empty(self.compress_rate, 2 * self.head_dim))
+        self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.rotary_emb = DeepseekV4RotaryEmbedding(config)
+        self.indexer = DeepseekV4Indexer(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        q_residual: torch.Tensor,
+        position_ids: torch.Tensor,
+        past_key_values: Cache | None,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        batch, seq_len, _ = hidden_states.shape
+        cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None
+        kv = self.kv_proj(hidden_states)
+        gate = self.gate_proj(hidden_states)
+        if cache_layer is None:
+            usable = (kv.shape[1] // self.compress_rate) * self.compress_rate
+            chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0
+            prior_kv, prior_gate = None, None
+        else:
+            chunk_kv, chunk_gate, first_pool_position = cache_layer.update_compressor(kv, gate)
+            prior_kv, prior_gate = cache_layer.get_compressor_overlap()
+        if chunk_kv.shape[1] > 0:
+            n_windows = chunk_kv.shape[1] // self.compress_rate
+            chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1)
+            chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1)
+            # Bias the gate logits with the learned per-slot position bias (the ``B`` in
+            # ``softmax(Z_j + B)``) before the window is persisted and pooled.
+            chunk_gate = chunk_gate + self.position_bias.to(chunk_gate.dtype)
+            if cache_layer is not None:
+                # Persist the *raw* last full window (gate already biased) so the next
+                # forward call's first window can read its low-channel slice as prior.
+                cache_layer.set_compressor_overlap(chunk_kv[:, -1].clone(), chunk_gate[:, -1].clone())
+            chunk_kv, chunk_gate = _overlap_pool(chunk_kv, chunk_gate, prior_kv, prior_gate, self.head_dim)
+            # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that
+            # only differ by a small amount, especially with large window widths).
+            new_pooled = self.kv_norm(
+                (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2)
+            )
+            positions = (
+                (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position)
+                .unsqueeze(0)
+                .expand(batch, -1)
+            )
+            new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type)
+        else:
+            new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim))
+        pooled = (
+            new_pooled.unsqueeze(1)
+            if cache_layer is None
+            else cache_layer.update_compressor_pool(new_pooled).unsqueeze(1)
+        )
+        # Lightning Indexer: gather top-``index_topk`` pool entries per query.
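+        # Shape walkthrough for the gather below: ``pooled`` is ``[B, 1, T, head_dim]`` and the
+        # indexer returns ``[B, S, k]`` indices; gathering keeps each query's ``k`` selected
+        # entries and flattens them into a ``[B, 1, S*k, head_dim]`` long-range KV block that
+        # :class:`DeepseekV4Attention` concatenates after its sliding-window keys / values.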
+ topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + expanded = pooled.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) + idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float | int = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +COMPRESSOR_CLASSES = { + "sliding_attention": None, + "compressed_sparse_attention": DeepseekV4CSACompressor, + "heavily_compressed_attention": DeepseekV4HCACompressor, +} + + +# ----------------------------------------------------------------------------- +# Attention with sink. +# ----------------------------------------------------------------------------- + + +class DeepseekV4Attention(nn.Module): + """V4 attention block (paper §2.3). Single class for all three layer types — the + only thing that varies is the long-range branch (the ``compressor`` sub-module); + the surrounding QKV / RoPE / sink / sliding-window / output projection is + identical. The three layer types are dispatched by ``COMPRESSOR_CLASSES``: + + * ``sliding_attention``: ``compressor = None``; only the local sliding-window + K=V branch ("Full Attention"). + * ``compressed_sparse_attention``: :class:`DeepseekV4CSACompressor` — + low-compression overlapping-window pool plus a Lightning Indexer that keeps + the top-``index_topk`` pool entries per query (paper §2.3.1). + * ``heavily_compressed_attention``: :class:`DeepseekV4HCACompressor` — + high-compression non-overlapping-window pool, no indexer (paper §2.3.2). 
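+
+    For example (hypothetical layout), ``config.layer_types = ["sliding_attention",
+    "compressed_sparse_attention", "heavily_compressed_attention", ...]`` builds no compressor
+    for the first block, a :class:`DeepseekV4CSACompressor` (with its Lightning Indexer) for
+    the second, and a :class:`DeepseekV4HCACompressor` for the third.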
+
+    Block components (paper §2.3.3):
+
+    * Shared-KV Multi-Query Attention: ``num_key_value_heads = 1``; ``kv_proj`` projects
+      directly to that single KV head and the same tensor is read as both key and
+      value.
+    * Partial RoPE on the trailing ``qk_rope_head_dim`` channels of each head ("Partial
+      Rotary Positional Embedding"). RoPE is also applied with position ``-i`` to the
+      attention output's rope slice, so the contribution of each KV entry stays a
+      function of the *relative* distance to the query.
+    * RMSNorm on the queries (``q_norm``) and the compressed KV head (``kv_norm``)
+      right before the core attention, to keep logits bounded.
+    * Per-head learnable attention sink (eq. 27).
+    * Grouped low-rank output projection (§2.3.1, "Grouped Output Projection"):
+      ``g`` head-groups → ``d_g``-dim intermediate outputs through a block-diagonal
+      :class:`DeepseekV4GroupedLinear`, then mixed back to ``hidden_size`` by ``o_b_proj``.
+    * A supplementary uncompressed sliding-window KV branch of size
+      ``sliding_window`` ("Additional Branch of Sliding Window Attention") that
+      preserves local fine-grained dependencies, concatenated with the
+      long-range compressor's output before core attention.
+    """
+
+    def __init__(self, config: DeepseekV4Config, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.layer_type = config.layer_types[layer_idx]
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_groups = config.num_attention_heads  # single KV head, broadcast to all
+        self.head_dim = config.head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.sliding_window = config.sliding_window
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+        self.scaling = self.head_dim**-0.5
+
+        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False)
+        self.q_norm = DeepseekV4RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False)
+        self.q_head_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps)
+        self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
+        self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.o_a_proj = DeepseekV4GroupedLinear(
+            self.num_heads * self.head_dim // config.o_groups, config.o_groups * config.o_lora_rank, config.o_groups
+        )
+        self.o_b_proj = nn.Linear(config.o_groups * config.o_lora_rank, config.hidden_size, bias=False)
+        self.sinks = nn.Parameter(torch.empty(self.num_heads))
+        # Long-range branch dispatched by ``layer_type`` (see ``COMPRESSOR_CLASSES``
+        # above). ``None`` means full-attention / sliding-only — no compressor is
+        # built and the layer keeps just the local sliding-window K=V branch.
+ compressor_cls = COMPRESSOR_CLASSES[self.layer_type] + self.compressor = compressor_cls(config) if compressor_cls is not None else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_ids: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch, seq_len = hidden_states.shape[:2] + cos, sin = position_embeddings + q_residual = self.q_norm(self.q_a_proj(hidden_states)) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = self.q_head_norm(q) + kv = self.kv_norm(self.kv_proj(hidden_states)).view(batch, seq_len, 1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb(q, cos, sin) + kv = apply_rotary_pos_emb(kv, cos, sin) + # Under TP, ``q_b_proj`` is colwise-sharded so ``q`` carries + # ``num_heads / tp_size`` heads per rank, while ``kv`` (single shared head) is + # replicated. Refresh ``num_key_value_groups`` from the local ``q`` so the + # standard ``repeat_kv(key, num_key_value_groups)`` in the attention backends + # lifts the single kv head to match the rank-local query head count. + self.num_key_value_groups = q.shape[1] + + # --- Sliding-window K=V branch goes through the standard cache update --- + if past_key_values is not None: + kv, _ = past_key_values.update(kv, kv, self.layer_idx) + + if self.compressor is None: + full_kv = kv + else: + compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + full_kv = torch.cat([kv, compressed_kv], dim=2) + + # Compressor concatenates extra long-range KV entries onto the sliding-window + # branch so ``full_kv`` is wider than the model-level mask was built for. Pad + # the mask's key dim with ``0.0`` (additive-mask convention: unmasked) so the + # compressor positions are attended. FA / SDPA backends consume the same 4D + # additive mask path here, so the pad is uniform across backends. + if attention_mask is not None and full_kv.shape[2] > attention_mask.shape[-1]: + attention_mask = F.pad(attention_mask, (0, full_kv.shape[2] - attention_mask.shape[-1]), value=0.0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + q, + full_kv, + full_kv, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, + **kwargs, + ) + + # De-rotate the output's rope slice. V4 shares K and V (``kv_proj`` projects to a + # single tensor), so V's rope slice carries the same per-token rotation as K. + # Attention sums V-rotated values across attended positions, so the output's + # rope slice is a position-mixed content; conjugate rotation at the query + # position pulls it back into a position-independent frame before the output + # projection mixes heads. 
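+        # Concretely: the value written at source position ``j`` carries rotation ``R(jθ)``
+        # (K and V share the projection), so rotating the summed output by ``R(-iθ)`` at query
+        # position ``i`` leaves each contribution rotated by ``R((j - i)θ)``, a function of the
+        # relative offset only, before ``o_a_proj`` / ``o_b_proj`` mix the heads.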
+        attn_output = apply_rotary_pos_emb(attn_output.transpose(1, 2), cos, -sin).transpose(1, 2)
+
+        grouped = attn_output.reshape(batch, seq_len, -1).view(batch, seq_len, self.config.o_groups, -1)
+        return self.o_b_proj(self.o_a_proj(grouped).flatten(2)), attn_weights
+
+
+class DeepseekV4HyperConnection(nn.Module):
+    r"""
+    Manifold-Constrained Hyper-Connections (mHC; Xie et al., 2026), which replace and strengthen
+    the conventional residual connections between adjacent Transformer blocks with ``hc_mult``
+    parallel residual streams.
+
+    Owns the learned (``fn``, ``base``, ``scale``) parameters that turn the incoming ``hc_mult``
+    residual streams into collapse / expand weights. The decoder layer instantiates two of these
+    (one for the attention site, one for the mlp site).
+
+    ASCII shape guide — ``B`` = batch, ``S`` = seq, ``H`` = hc_mult, ``D`` = hidden_size::
+
+        hidden_streams      flatten(2)        RMSNorm-rescale + F.linear(fn)
+        [B, S, H, D]  ──────────► [B, S, H*D] ─────────────────────────────────►
+                                                   mix-logits
+                                                   [B, S, (2+H)*H]
+                                                        │
+            ┌───────────────────────────────────────────┴──────────────────────────────┐
+            ▼                                           ▼                              ▼
+        pre logits                                post logits                    comb logits
+        [B, S, H]                                 [B, S, H]                      [B, S, H, H]
+        × scale[0]                                × scale[1]                     × scale[2]
+        + base[:H]                                + base[H:2H]                   + base[2H:]
+        σ() + eps                                 σ() + eps                      σ() + eps
+            │                                           │                              │
+            pre                                       post                     Sinkhorn(iters)
+        (stream collapse weights)         (block-output placement)            row/col normalise
+                                                                                        │
+                                                                                      comb
+                                                                                 (stream mixer)
+    """
+
+    def __init__(self, config: DeepseekV4Config):
+        super().__init__()
+        self.hc_mult = config.hc_mult
+        self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
+        self.hc_eps = config.hc_eps
+        self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps)
+        mix = (2 + self.hc_mult) * self.hc_mult
+        self.fn = nn.Parameter(torch.empty(mix, self.hc_mult * config.hidden_size))
+        self.base = nn.Parameter(torch.empty(mix))
+        # 3 = number of outputs from the mHC mapping: ``pre`` (input projection
+        # weights), ``post`` (sublayer output projection weights), ``comb`` (the
+        # H×H residual combine matrix that gets Sinkhorn-projected onto the
+        # doubly-stochastic manifold). Each output gets its own learned scale.
+        self.scale = nn.Parameter(torch.empty(3))
+
+    def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        r"""Compute the ``(pre, post, comb)`` mixing weights and collapse the residual streams.
+
+        The ``comb`` logits ``B̃_l`` are projected onto the manifold of doubly stochastic
+        matrices ``M`` with the Sinkhorn-Knopp algorithm: the paper first applies an exponential
+        to ensure positivity, getting ``M^(0) = exp(B̃_l)`` (a sigmoid plus ``hc_eps`` plays that
+        role here), and then iteratively performs column and row normalization (paper eq. 8):
+
+            M^(t) = T_r(T_c(M^(t-1)))
+
+        where ``T_r`` and ``T_c`` denote row and column normalization, respectively.
+ """ + flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) + mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] + pre_scale, post_scale, comb_scale = self.scale.unbind(0) + hc = self.hc_mult + pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps + post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps + comb = ( + torch.sigmoid( + mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + ) + + self.hc_eps + ) + for _ in range(self.hc_sinkhorn_iters): + comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + # Collapse the ``hc_mult`` parallel streams down to a single sequence using + # the ``pre`` weights (Manifold-Constrained input projection): one weighted + # sum across the stream axis, ready for the sublayer (attn / MLP). + collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2).to(hidden_streams.dtype) + return post, comb, collapsed + + +class DeepseekV4HyperHead(nn.Module): + """Final HC-stream collapse; used by ``DeepseekV4Model`` before the shared RMSNorm.""" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + self.eps = config.hc_eps + self.hc_fn = nn.Parameter(torch.empty(self.hc_mult, self.hc_mult * config.hidden_size)) + self.hc_base = nn.Parameter(torch.empty(self.hc_mult)) + self.hc_scale = nn.Parameter(torch.empty(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + flat = self.input_norm(x.flatten(2).float()) + mixes = F.linear(flat, self.hc_fn.float()) + pre = torch.sigmoid(mixes * self.hc_scale.float() + self.hc_base.float()) + self.eps + return (pre.unsqueeze(-1) * x).sum(dim=2).to(x.dtype) + + +class DeepseekV4MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +@use_experts_implementation +class DeepseekV4Experts(nn.Module): + """Routed experts (paper §2.1). Inherits the Mixtral layout (no biases, + ``[num_experts, 2 * intermediate, hidden]`` ``gate_up_proj``) and the per-expert + iteration loop; the only V4-specific bit is the ``swiglu_limit`` clamp on the + gate / up halves before the SiLU mix. + + ``config.intermediate_size`` resolves to ``moe_intermediate_size`` via the + config's ``attribute_map``, so ``MixtralExperts.__init__`` builds the right + Linear shapes without an override. 
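+
+    In pseudo-form, the per-expert computation implemented in ``forward`` below is (a sketch,
+    assuming ``hidden_act="silu"``)::
+
+        gate, up = (x @ gate_up_proj[e].T).chunk(2, dim=-1)
+        y = F.silu(gate.clamp(max=swiglu_limit)) * up.clamp(min=-swiglu_limit, max=swiglu_limit)
+        out = (y @ down_proj[e].T) * routing_weight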
+ """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + self.limit = config.swiglu_limit + + def forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(mask[expert_idx]) + gate, up = F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + current = self.act_fn(gate) * up + current = F.linear(current, self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final + + +class DeepseekV4TopKRouter(nn.Module): + """DeepSeekMoE top-k router (paper §2.1, "Mixture-of-Experts"). Two changes from + the V3 router: + + * The expert affinity activation is ``Sqrt(Softplus(·))`` instead of the V3 + Sigmoid (paper §2.1: "we change the activation function that computes the + affinity scores from Sigmoid(·) into Sqrt(Softplus(·))"). The ``scoring_func`` + config field selects this for V4 checkpoints. + * The constraint on the number of routing target nodes used in V3 is dropped, + and the V3 ``n_group`` / ``topk_group`` machinery is removed entirely (paper + §2.1: "we remove the constraint on the number of routing target nodes"). + + The auxiliary-loss-free strategy (DeepSeek's ``noaux_tc``) is preserved via the + per-expert ``e_score_correction_bias`` buffer that biases the top-k argmax + without flowing gradients. + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts), persistent=True) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat.float(), self.weight.float()) + scores = self.score_fn(logits) + indices = torch.topk(scores + self.e_score_correction_bias, self.top_k, dim=-1, sorted=False).indices + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4HashRouter(nn.Module): + """Hash routing for the first ``mlp_layer_types == "hash_moe"`` MoE layers (paper + §2.1). Expert selection is determined by a fixed ``tid2eid[input_ids]`` lookup — + a frozen token-id → expert-id table — instead of a learned argmax. 
The learned + gate ``weight`` still produces the per-expert scores that weight the selected + experts' activations; only the *which-experts* selection is static. + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("tid2eid", torch.zeros(config.vocab_size, self.top_k, dtype=torch.long), persistent=True) + + def forward( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat.float(), self.weight.float()) + scores = self.score_fn(logits) + indices = self.tid2eid[input_ids.reshape(-1)].long() + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4SparseMoeBlock(nn.Module): + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.is_hash = config.mlp_layer_types[layer_idx] == "hash_moe" + self.gate = DeepseekV4HashRouter(config) if self.is_hash else DeepseekV4TopKRouter(config) + self.experts = DeepseekV4Experts(config) + self.shared_experts = DeepseekV4MLP(config) + + def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None) -> torch.Tensor: + batch, seq_len, hidden_dim = hidden_states.shape + residual = hidden_states + flat = hidden_states.view(-1, hidden_dim) + if self.is_hash: + if input_ids is None: + raise ValueError( + "DeepseekV4's hash-routing layers need `input_ids` to look up expert indices. " + "The `inputs_embeds`-only inference path is not supported for models with " + "any `hash_moe` entries in `mlp_layer_types`." + ) + _, weights, indices = self.gate(hidden_states, input_ids) + else: + _, weights, indices = self.gate(hidden_states) + routed = self.experts(flat, indices, weights).view(batch, seq_len, hidden_dim) + return routed + self.shared_experts(residual) + + +class DeepseekV4DecoderLayer(GradientCheckpointingLayer): + r"""DeepSeek-V4 decoder block (paper §2). Differs from a classic residual block in + two places: + + * The residual is a stack of ``hc_mult`` parallel streams kept in shape + ``[B, S, hc_mult, D]`` throughout the block, mixed in and out via two + :class:`DeepseekV4HyperConnection` modules (Manifold-Constrained Hyper- + Connections / mHC, paper §2.2; Xie et al., 2026). The mHC mappings constrain + the residual transform to the manifold of doubly-stochastic matrices via the + Sinkhorn-Knopp projection — making signal propagation non-expansive across + deep stacks. + * ``self_attn`` is :class:`DeepseekV4Attention` for every layer. Its compressor + sub-module is the only thing that varies by layer type + (:class:`DeepseekV4HCACompressor` for HCA layers, + :class:`DeepseekV4CSACompressor` for CSA, picked via + ``config.layer_types[layer_idx]``); the CSA compressor also owns the + Lightning Indexer at ``self_attn.compressor.indexer``. 
+ + Classic residual decoder layer:: + + h ──► norm ──► self_attn ──► + ──► norm ──► mlp ──► + + └──────── residual ────────┘ └─────── residual ───┘ + + Deepseek V4 decoder layer (``H = hc_mult`` parallel residual streams throughout):: + + attention site mlp site + ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ + │ hidden_streams [B, S, H, D] │ │ hidden_streams [B, S, H, D] │ + │ │ │ │ │ │ + │ attn_hc(streams) ─► (pre, post, comb) │ │ ffn_hc(streams) ─► (pre, post, comb) │ + │ │ │ │ │ │ + │ Σ pre·streams (collapse) │ │ Σ pre·streams (collapse) │ + │ │ │ │ │ │ + │ input_layernorm │ │ post_attention_layernorm │ + │ │ │ │ │ │ + │ self_attn │ │ mlp (MoE routed + shared) │ + │ │ │ │ │ │ + │ post·output + comb·streams (expand) │ │ post·output + comb·streams (expand) │ + │ │ │ │ │ │ + │ ▼ │ │ ▼ │ + │ new hidden_streams ──────────────────┘ │ new hidden_streams │ + └────────────────────────────────────────┘ └────────────────────────────────────────┘ + + + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = DeepseekV4Attention(config, layer_idx) + self.mlp = DeepseekV4SparseMoeBlock(config, layer_idx) + self.input_layernorm = DeepseekV4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn_hc = DeepseekV4HyperConnection(config) + self.ffn_hc = DeepseekV4HyperConnection(config) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + # hidden_states throughout: [B, S, hc_mult, hidden]. + + # --- Attention site: collapse → norm → attn → expand --- + post, comb, collapsed = self.attn_hc(hidden_states) + attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) + dtype = hidden_states.dtype + hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype), hidden_states + ) + + # --- MLP site: collapse → norm → mlp → expand --- + post, comb, collapsed = self.ffn_hc(hidden_states) + mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) + dtype = hidden_states.dtype + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + + +@auto_docstring +class DeepseekV4PreTrainedModel(PreTrainedModel): + config: DeepseekV4Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV4DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + # V4 ships eager-only: the compressor / indexer paths weren't validated against + # SDPA / FlashAttention / FlexAttention kernels — leaving these ``False`` makes + # ``set_attn_implementation`` reject those backends instead of silently routing + # through them. + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + # The compressor's rolling-window buffer / pool / overlap state lives on the + # per-layer cache (:class:`DeepseekV4HCACache` / :class:`DeepseekV4CSACache`) + # and isn't compatible with :class:`StaticCache` — that path would hand the + # compressor a :class:`StaticSlidingWindowLayer` with no ``update_compressor`` + # method. Disabling fullgraph compile keeps generation tests on the dynamic + # cache build that does dispatch to V4's own cache layers. 
+ _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "router_logits": OutputRecorder(DeepseekV4TopKRouter, index=0), + "hidden_states": DeepseekV4DecoderLayer, + "attentions": DeepseekV4Attention, + } + config_class = DeepseekV4Config + _keep_in_fp32_modules_strict = ["attn_hc", "ffn_hc"] + _keys_to_ignore_on_load_unexpected = [r"model\.mtp\..*"] + # ``_is_stateful`` opts out of generation modes that need to roll the cache + # back across drafts (assisted generation, prompt lookup, contrastive search). + # The compressor's running-window state isn't rewindable, so ``generate`` + # raises a clear error early instead of failing deep in the compressor with + # a missing-method ``AttributeError``. + _is_stateful = True + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, (DeepseekV4TopKRouter, DeepseekV4HashRouter)): + init.normal_(module.weight, mean=0.0, std=std) + if isinstance(module, DeepseekV4TopKRouter): + init.zeros_(module.e_score_correction_bias) # buffer + if isinstance(module, DeepseekV4HashRouter): + init.zeros_(module.tid2eid) # buffer; real values come from the checkpoint + elif isinstance(module, DeepseekV4Experts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, DeepseekV4Attention): + init.zeros_(module.sinks) + elif isinstance(module, DeepseekV4HyperConnection): + init.normal_(module.fn, mean=0.0, std=std) + init.zeros_(module.base) + init.ones_(module.scale) + elif isinstance(module, DeepseekV4HyperHead): + init.normal_(module.hc_fn, mean=0.0, std=std) + init.zeros_(module.hc_base) + init.ones_(module.hc_scale) + elif isinstance(module, (DeepseekV4HCACompressor, DeepseekV4CSACompressor, DeepseekV4Indexer)): + init.zeros_(module.position_bias) + elif isinstance(module, DeepseekV4RotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq) + init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq) + + +@auto_docstring +class DeepseekV4Model(DeepseekV4PreTrainedModel): + def __init__(self, config: DeepseekV4Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV4RotaryEmbedding(config) + self.gradient_checkpointing = False + self.hc_head = DeepseekV4HyperHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> 
MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + return_cache = past_key_values if use_cache else None + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: + past_seen = past_key_values.get_seq_length() + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen + position_ids = position_ids.unsqueeze(0) + # ``generate()`` may pass a per-layer-type mask dict already built by + # ``create_masks_for_generate``; all V4 layer types use the same sliding-window + # mask, so use the prebuilt one directly. Otherwise build it here. + if isinstance(attention_mask, dict): + causal_mask = next(iter(attention_mask.values())) + else: + causal_mask = create_sliding_window_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() + position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + position_ids=position_ids, + attention_mask=causal_mask, + input_ids=input_ids, + past_key_values=past_key_values, + **kwargs, + ) + + hidden_states = self.norm(self.hc_head(hidden_states)) + return MoeModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=return_cache) + + +def load_balancing_loss_func( + gate_logits: torch.Tensor | tuple[torch.Tensor] | None, + num_experts: int | None = None, + top_k=2, + attention_mask: torch.Tensor | None = None, +) -> torch.Tensor | int: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
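+
+    Example (illustrative only; random logits with the documented shapes)::
+
+        >>> import torch
+        >>> # two MoE layers, a batch of 4 tokens, 8 experts, top-2 routing
+        >>> gate_logits = tuple(torch.randn(4, 8) for _ in range(2))
+        >>> aux = load_balancing_loss_func(gate_logits, num_experts=8, top_k=2)
+        >>> aux.shape
+        torch.Size([])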
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +@auto_docstring +class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV4Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_router_logits: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV4ForCausalLM + + >>> model = DeepseekV4ForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V4-Flash") + >>> tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V4-Flash") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_router_logits=output_router_logits, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure it resides on the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +__all__ = ["DeepseekV4PreTrainedModel", "DeepseekV4Model", "DeepseekV4ForCausalLM"] diff --git a/src/transformers/models/deepseek_v4/modular_deepseek_v4.py b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py new file mode 100644 index 000000000000..3e7402bd7ace --- /dev/null +++ b/src/transformers/models/deepseek_v4/modular_deepseek_v4.py @@ -0,0 +1,1292 @@ +# Copyright 2026 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +from collections.abc import Callable + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowLayer +from ...integrations import use_experts_implementation +from ...masking_utils import create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.generic import maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import OutputRecorder, capture_outputs +from ..deepseek_v3.modeling_deepseek_v3 import DeepseekV3RMSNorm +from ..glm.modeling_glm import rotate_half +from ..gpt_oss.modeling_gpt_oss import eager_attention_forward +from ..laguna.modeling_laguna import LagunaRotaryEmbedding +from ..llama.modeling_llama import LlamaMLP, LlamaModel +from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralPreTrainedModel, MixtralTopKRouter +from .configuration_deepseek_v4 import DeepseekV4Config + + +def apply_rotary_pos_emb( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> torch.Tensor: + """V4-Flash interleaved RoPE (matches the reference's ``apply_rotary_emb`` at + ``inference/model.py:232``). Rotates the **trailing** ``2 * cos.shape[-1]`` channels + of ``x`` (V4-Flash lays each head out as ``[nope | rope]``, matching the reference's + ``x[..., -rd:]`` indexing) and leaves the leading nope channels unchanged. The + half-sized cos / sin from :class:`DeepseekV4RotaryEmbedding` are expanded to the + pair-aligned full rope dim via ``repeat_interleave`` here, and the rotation is + done in fp32 (matching the reference's precision).""" + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(unsqueeze_dim) + rope_dim = cos.shape[-1] + nope, rope = x[..., :-rope_dim], x[..., -rope_dim:] + rotated = ((rope.float() * cos) + (rotate_half(rope).float() * sin)).to(x.dtype) + return torch.cat([nope, rotated], dim=-1) + + +logger = logging.get_logger(__name__) + + +class DeepseekV4RMSNorm(DeepseekV3RMSNorm): + pass + + +class DeepseekV4UnweightedRMSNorm(nn.Module): + """RMSNorm without a learned weight — applied per-head to Q after ``q_b_proj`` + in :class:`DeepseekV4Attention`. Matches the V4-Flash reference's ``inference/ + model.py:498`` rescale ``q *= rsqrt(mean(q**2) + eps)``; without it attention + scores end up at the wrong scale and the model collapses to a single repeated + token within a handful of layers. + """ + + def __init__(self, eps: float = 1.0e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.rsqrt(x.float().square().mean(-1, keepdim=True) + self.eps).to(x.dtype) + + +class DeepseekV4RotaryEmbedding(LagunaRotaryEmbedding): + """Multi-layer-type rotary embedding (Laguna pattern: partial rotary on top of + Gemma3's per-layer-type buffers), specialised for V4's *interleaved* RoPE. Holds + two ``inv_freq`` buffers — ``"main"`` for self-attention (``rope_theta``) and + ``"compress"`` for the Compressor / Indexer (``compress_rope_theta``); both + honour ``partial_rotary_factor`` so cos/sin sizes to ``qk_rope_head_dim``. 
+ ``forward(x, position_ids, layer_type=...)`` returns the half-sized cos/sin + directly — interleaved RoPE rotates pairs ``(x[2i], x[2i+1])`` so we want one + ``θ_i`` per pair, *not* the end-to-end duplicated table half-split RoPE needs. + + The ``layer_types`` here are the *rope* layer types (``"main"`` / ``"compress"``), + keys of ``config.rope_parameters``. They are unrelated to ``config.layer_types``, + which lists the per-decoder-block attention type. + """ + + # Class-level rather than ``list(set(config.layer_types))`` (gemma3's pattern): + # V4's rope keys ``"main"`` / ``"compress"`` are *orthogonal* to the per-block + # attention types in ``config.layer_types`` — every attention block uses ``main``, + # every compressor / indexer uses ``compress``, regardless of which of the three + # block types the layer is. + layer_types = ("main", "compress") + + def __init__(self, config: "DeepseekV4Config", device=None): + nn.Module.__init__(self) + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + self.rope_type = {} + for layer_type in set(self.layer_types): + params = config.rope_parameters.get(layer_type) + if params is None: + continue + self.rope_type[layer_type] = params.get("rope_type", "default") + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]] + inv_freq, scaling = rope_init_fn(config, device, layer_type=layer_type) + self.register_buffer(f"{layer_type}_inv_freq", inv_freq, persistent=False) + self.register_buffer(f"{layer_type}_original_inv_freq", inv_freq.clone(), persistent=False) + setattr(self, f"{layer_type}_attention_scaling", scaling) + + def forward(self, x, position_ids, layer_type=None): + # Interleaved RoPE: one ``θ_i`` per pair (``rope_head_dim // 2`` entries), + # no end-to-end duplication. Same shape as ``inv_freq @ position_ids``. + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + cos = freqs.cos() * attention_scaling + sin = freqs.sin() * attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def _sliding_kv_update( + cache_layer: "DynamicSlidingWindowLayer", key_states: torch.Tensor, value_states: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Shared sliding-window K=V update body. 
V4 uses shared-KV MQA, so ``keys`` and + ``values`` point to the same storage on every layer; both V4 cache layer types + (HCA / CSA) call this from their ``update``.""" + if not cache_layer.is_initialized: + cache_layer.lazy_initialization(key_states, value_states) + cache_layer.values = cache_layer.keys + cache_layer.cumulative_length += key_states.shape[-2] + full = torch.cat([cache_layer.keys, key_states], dim=-2) + cache_layer.keys = full[:, :, -cache_layer.sliding_window + 1 :, :] + cache_layer.values = cache_layer.keys + return full, full + + +def _update_window_buffer( + buffer_kv: torch.Tensor | None, + buffer_gate: torch.Tensor | None, + kv: torch.Tensor, + gate: torch.Tensor, + compress_rate: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Merge a still-buffered tail with freshly projected ``(kv, gate)`` and split off + the longest window-aligned chunk. Used by both the compressor- and indexer-side + window buffers; tokens past the last full window stay in the buffer until the + next call rounds them out to a multiple of ``compress_rate``.""" + if buffer_kv is not None and buffer_kv.shape[1]: + kv = torch.cat([buffer_kv, kv], dim=1) + gate = torch.cat([buffer_gate, gate], dim=1) + usable = (kv.shape[1] // compress_rate) * compress_rate + return kv[:, :usable], gate[:, :usable], kv[:, usable:], gate[:, usable:] + + +def _append_to_pool(pool: torch.Tensor | None, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted compressed entries to a running pool, returning the + full pool (or an empty tensor if nothing has been pooled yet).""" + if new_pooled.shape[1] > 0: + return new_pooled if pool is None else torch.cat([pool, new_pooled], dim=1) + if pool is None: + return new_pooled.new_zeros((new_pooled.shape[0], 0, new_pooled.shape[-1])) + return pool + + +class DeepseekV4HCACache(DynamicSlidingWindowLayer): + """Cache layer for HCA blocks (paper §2.3.2). Holds the long-range compressor's + buffer / pool / count on top of the sliding-window K=V branch. HCA uses + *non-overlapping* windows, so there is **no** overlap state, and HCA has **no** + indexer either. + + Fields on top of :class:`DynamicSlidingWindowLayer`: + + * ``compressor_pool`` — the running list of compressed KV entries emitted so + far (one per ``compress_rate_hca`` source tokens; the long-range KVs the + attention concatenates onto its sliding-window keys / values). + * ``compressor_buffer_kv`` / ``compressor_buffer_gate`` — source tokens that + arrived between two full windows; once the buffer hits ``compress_rate_hca`` + tokens the compressor closes a window, emits one pooled entry, and drains + the buffer. + * ``compressor_pool_count`` — number of compressed entries emitted so far, + so ``compressor_pool_count * compress_rate_hca`` is the absolute position + of the *next* window's first source token. + + The class-level ``layer_type`` auto-registers this class with + :data:`LAYER_TYPE_CACHE_MAPPING` so :class:`DynamicCache` builds it on its own + when ``config.layer_types[i] == "heavily_compressed_attention"``. 
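+
+    Worked example (illustrative numbers, with the default ``compress_rate_hca = 128``): after a
+    300-token prompt the compressor has closed two windows, so ``compressor_pool`` holds
+    2 entries, ``compressor_pool_count == 2``, the remaining 44 tokens sit in
+    ``compressor_buffer_kv`` / ``compressor_buffer_gate``, and the next window's first
+    source token is at absolute position ``2 * 128 = 256``.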
+ """ + + layer_type = "heavily_compressed_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["heavily_compressed_attention"] + self.compressor_buffer_kv: torch.Tensor | None = None + self.compressor_buffer_gate: torch.Tensor | None = None + self.compressor_pool: torch.Tensor | None = None + self.compressor_pool_count = 0 + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs): + return _sliding_kv_update(self, key_states, value_states) + + def update_compressor(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Merge the freshly projected ``(kv, gate)`` (paper §2.3.2 eqs. 20–21: + ``C = H·W^{KV}``, ``Z = H·W^Z``) with the buffered tail from prior calls and + return the longest window-aligned chunk that's ready to pool, plus the + absolute source-token position of that chunk's first window. The returned + chunk is softmax-pooled by the compressor with ``position_bias`` to emit one + compressed entry per window of ``compress_rate_hca`` tokens (eqs. 22–23).""" + first_pool_position = self.compressor_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.compressor_buffer_kv, self.compressor_buffer_gate = _update_window_buffer( + self.compressor_buffer_kv, self.compressor_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_compressor_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted compressed entries to ``compressor_pool`` + (``C^{Comp}``, paper §2.3.2 eq. 23) and return the full pool. Bumps + ``compressor_pool_count`` so the next ``update_compressor`` call knows the + absolute source-token position of its first window.""" + self.compressor_pool = _append_to_pool(self.compressor_pool, new_pooled) + self.compressor_pool_count += new_pooled.shape[1] + return self.compressor_pool + + +class DeepseekV4CSACache(DynamicSlidingWindowLayer): + """Cache layer for CSA blocks (paper §2.3.1). Holds two parallel sets of + buffer / pool / count / overlap state on top of the sliding-window K=V branch: + + * **compressor side** — the main-branch ``head_dim`` pool (the long-range KVs + the attention concatenates after top-k indexer selection). + * **indexer side** — the Lightning Indexer's smaller ``index_head_dim`` pool + (the keys ``K^{IComp}`` that queries score against to pick the top-k blocks, + eqs. 14–17). Kept separate from the compressor pool because the head dim + differs. + + Both sides use **overlapping** windows of stride ``compress_rate_csa`` and width + ``2 * compress_rate_csa`` (paper §2.3.1), so each side also keeps an + ``*_overlap_kv`` / ``*_overlap_gate`` pair holding the last full window's + projected ``(kv, gate)`` so the next forward call's first window can stitch in + its low-channel slice as the prior contribution. + + The class-level ``layer_type`` auto-registers this class with + :data:`LAYER_TYPE_CACHE_MAPPING` so :class:`DynamicCache` builds it on its own + when ``config.layer_types[i] == "compressed_sparse_attention"``. 
+ """ + + layer_type = "compressed_sparse_attention" + + def __init__(self, config: "DeepseekV4Config"): + super().__init__(config) + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + # Compressor side + self.compressor_buffer_kv: torch.Tensor | None = None + self.compressor_buffer_gate: torch.Tensor | None = None + self.compressor_pool: torch.Tensor | None = None + self.compressor_pool_count = 0 + self.compressor_overlap_kv: torch.Tensor | None = None + self.compressor_overlap_gate: torch.Tensor | None = None + # Indexer side (parallel state at ``index_head_dim``) + self.indexer_buffer_kv: torch.Tensor | None = None + self.indexer_buffer_gate: torch.Tensor | None = None + self.indexer_pool: torch.Tensor | None = None + self.indexer_pool_count = 0 + self.indexer_overlap_kv: torch.Tensor | None = None + self.indexer_overlap_gate: torch.Tensor | None = None + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, *args, **kwargs): + return _sliding_kv_update(self, key_states, value_states) + + def update_compressor(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Compressor-side window buffer (paper §2.3.1 main-branch pool, eqs. 9–12). + Same window-aligned tail-buffering as HCA, but at the CSA cadence + (``compress_rate_csa``).""" + first_pool_position = self.compressor_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.compressor_buffer_kv, self.compressor_buffer_gate = _update_window_buffer( + self.compressor_buffer_kv, self.compressor_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_compressor_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted entries to the CSA compressor pool (the + ``C^{Comp}`` running list at ``head_dim``, eqs. 11–12).""" + self.compressor_pool = _append_to_pool(self.compressor_pool, new_pooled) + self.compressor_pool_count += new_pooled.shape[1] + return self.compressor_pool + + def get_compressor_overlap(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + return self.compressor_overlap_kv, self.compressor_overlap_gate + + def set_compressor_overlap(self, kv: torch.Tensor, gate: torch.Tensor) -> None: + self.compressor_overlap_kv = kv + self.compressor_overlap_gate = gate + + def update_indexer(self, kv: torch.Tensor, gate: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: + """Indexer-side mirror of :meth:`update_compressor` (paper §2.3.1, "Lightning + Indexer for Sparse Selection"). Same logic at the smaller ``index_head_dim`` + — the small-head pool keys ``K^{IComp}`` (eq. 14's ``W^{IUQ}`` complement on + the key side) that the indexer scores queries against to pick the top-k + blocks (eqs. 15–17). Buffer / pool / count are kept separate from the + compressor's state because the head dim differs.""" + first_pool_position = self.indexer_pool_count * self.compress_rate + chunk_kv, chunk_gate, self.indexer_buffer_kv, self.indexer_buffer_gate = _update_window_buffer( + self.indexer_buffer_kv, self.indexer_buffer_gate, kv, gate, self.compress_rate + ) + return chunk_kv, chunk_gate, first_pool_position + + def update_indexer_pool(self, new_pooled: torch.Tensor) -> torch.Tensor: + """Append freshly emitted entries to the indexer pool ``K^{IComp}`` (paper + §2.3.1 eq. 16: the keys against which the ``q^I_t`` queries score for top-k + selection). 
Same cadence as the compressor pool — one entry per + ``compress_rate_csa`` source tokens — but at ``index_head_dim``.""" + self.indexer_pool = _append_to_pool(self.indexer_pool, new_pooled) + self.indexer_pool_count += new_pooled.shape[1] + return self.indexer_pool + + def get_indexer_overlap(self) -> tuple[torch.Tensor | None, torch.Tensor | None]: + return self.indexer_overlap_kv, self.indexer_overlap_gate + + def set_indexer_overlap(self, kv: torch.Tensor, gate: torch.Tensor) -> None: + self.indexer_overlap_kv = kv + self.indexer_overlap_gate = gate + + +class DeepseekV4GroupedLinear(nn.Linear): + """Block-diagonal grouped linear used by the V4 grouped output projection + (paper §2.3.1, "Grouped Output Projection"; HCA reuses the same scheme, + §2.3.2). With ``num_attention_heads = n_h`` and per-head dim ``c``, the core + attention's stacked output is ``c·n_h``-dim, which is *very* large for V4 + (V4-Flash: c=512, n_h=64 → 32768; V4-Pro: c=512, n_h=128 → 65536). A direct + ``c·n_h → hidden_size`` projection would dominate the per-token cost. + + The paper sidesteps that by splitting the n_h heads into ``g`` groups, projecting + each ``c·n_h/g``-dim group independently to a ``d_g``-dim intermediate output + (with ``d_g < c·n_h/g``), and then mixing the resulting ``g·d_g`` vector to + ``hidden_size`` through a single follow-up linear (``self_attn.o_b_proj``). This + module owns the per-group block (``self_attn.o_a_proj``). + + The ``weight`` parameter is shaped like a standard ``nn.Linear`` + (``[out_features, in_features_per_group]``) so quantizers keyed on + ``nn.Linear.weight`` still pick it up; ``forward`` does the per-group ``bmm``. + """ + + def __init__(self, in_features_per_group: int, out_features: int, n_groups: int, bias: bool = False): + super().__init__(in_features_per_group, out_features, bias=bias) + self.n_groups = n_groups + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [..., n_groups, in_features_per_group] + input_shape = x.shape[:-2] + d_in = x.shape[-1] + w = self.weight.view(self.n_groups, -1, d_in).transpose(1, 2) + x = x.reshape(-1, self.n_groups, d_in).transpose(0, 1) + y = torch.bmm(x, w).transpose(0, 1) + return y.reshape(*input_shape, self.n_groups, -1) + + +def _overlap_pool( + chunk_kv: torch.Tensor, + chunk_gate: torch.Tensor, + prior_kv: torch.Tensor | None, + prior_gate: torch.Tensor | None, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Expand ``[B, n_win, ratio, 2*head_dim]`` chunks into ``[B, n_win, 2*ratio, head_dim]`` + by stitching each window's *low-channel* slice onto the *high-channel* slice of the + prior window — matching the V4-Flash reference (``Compressor.overlap_transform``). + + Each pooled output thus mixes ``ratio`` *current* source tokens (high half of the + learned 2d split) with ``ratio`` *previous* source tokens (low half), so windows + have width ``2*ratio`` but stride ``ratio`` (paper §2.3.1). For window 0, the prior + half is filled with zero (kv) / ``-inf`` (gate, so its softmax weight is exactly 0), + unless ``prior_kv`` / ``prior_gate`` carry the last full window from a previous + forward call — in which case its low-channel slice slots into row ``[0, :ratio]``. 
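+
+    Shape sketch (illustrative numbers): with ``ratio = 4`` and ``head_dim = 64``, ``chunk_kv``
+    of shape ``[B, n_win, 4, 128]`` becomes ``new_kv`` of shape ``[B, n_win, 8, 64]`` — rows
+    ``4:`` take the high 64 channels of the current window, rows ``:4`` take the low 64
+    channels of the previous window (or the zero / ``-inf`` fill for window 0).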
+ """ + batch, n_windows, ratio, _ = chunk_kv.shape + new_kv = chunk_kv.new_zeros((batch, n_windows, 2 * ratio, head_dim)) + new_gate = chunk_gate.new_full((batch, n_windows, 2 * ratio, head_dim), float("-inf")) + new_kv[:, :, ratio:] = chunk_kv[..., head_dim:] + new_gate[:, :, ratio:] = chunk_gate[..., head_dim:] + if n_windows > 1: + new_kv[:, 1:, :ratio] = chunk_kv[:, :-1, :, :head_dim] + new_gate[:, 1:, :ratio] = chunk_gate[:, :-1, :, :head_dim] + if prior_kv is not None and prior_gate is not None: + new_kv[:, 0, :ratio] = prior_kv[..., :head_dim].to(new_kv.dtype) + new_gate[:, 0, :ratio] = prior_gate[..., :head_dim].to(new_gate.dtype) + return new_kv, new_gate + + +def _rope_pool(pooled: torch.Tensor, rotary_emb: nn.Module, positions: torch.Tensor, layer_type: str) -> torch.Tensor: + """Apply RoPE to the trailing rope slice of each pooled entry at its deterministic + absolute position. Used by both the indexer pool and the HCA / CSA compressor pools.""" + cos, sin = rotary_emb(pooled, position_ids=positions, layer_type=layer_type) + return apply_rotary_pos_emb(pooled.unsqueeze(1), cos, sin).squeeze(1) + + +class DeepseekV4Indexer(nn.Module): + """Lightning Indexer (paper §2.3.1, eqs. 13–17). Used by Compressed Sparse + Attention (CSA) to pick the top-k compressed KV blocks per query. The indexer + runs its own scaled-down compressor at ``index_head_dim`` over the same windows + as the outer CSA compressor, then scores queries against the pooled keys with + ``∑_h w_{t,h} · ReLU(q_{t,h} · K^IComp_s)`` and keeps the top ``index_topk`` + indices. + + Class-attribute ``rope_layer_type`` selects which inv_freq buffer of the shared + :class:`DeepseekV4RotaryEmbedding` to use; the indexer always reads + ``"compress"`` (paired with ``compress_rope_theta``). + + The indexer has its own rotary because it applies RoPE to two sets of tensors: + + * **pool keys** at deterministic positions ``i * compress_rate + first_pool_position``, + * **queries** at the model's current ``position_ids`` (variable per forward). + + Both must use the same theta as the outer compressor (``compress_rope_theta``) so + query/key inner products are translation-invariant in the standard rope sense — if + they used different thetas the score ``q · k`` would carry a residual position- + dependent skew. We can't precompute cos/sin once at init because the query + positions vary per call, so the indexer owns a rotary embedding and calls it with + ``layer_type=self.rope_layer_type`` twice per forward (once for pool keys, once for queries). + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + self.n_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.index_topk = config.index_topk + self.softmax_scale = self.head_dim**-0.5 + self.weights_scaling = self.n_heads**-0.5 + # The indexer always pools with the CSA cadence (``compress_rate=4``), so its + # inner pool runs the same overlapping-window scheme as :class:`DeepseekV4CSACompressor` + # (paper §2.3.1) — ``coff = 2`` everywhere on the pool branch. 
+ self.coff = 2 + self.kv_proj = nn.Linear(config.hidden_size, self.coff * self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, self.coff * self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, self.coff * self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.n_heads * self.head_dim, bias=False) + self.weights_proj = nn.Linear(config.hidden_size, self.n_heads, bias=False) + self.rotary_emb = DeepseekV4RotaryEmbedding(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> torch.LongTensor: + batch, seq_len, _ = hidden_states.shape + cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None + + # --- Pool side: same overlapping windows as the outer CSA compressor, at index_head_dim --- + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0 + prior_kv, prior_gate = None, None + else: + chunk_kv, chunk_gate, first_pool_position = cache_layer.update_indexer(kv, gate) + prior_kv, prior_gate = cache_layer.get_indexer_overlap() + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1) + chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1) + self.position_bias.to( + chunk_gate.dtype + ) + if cache_layer is not None: + cache_layer.set_indexer_overlap(chunk_kv[:, -1].clone(), chunk_gate[:, -1].clone()) + chunk_kv, chunk_gate = _overlap_pool(chunk_kv, chunk_gate, prior_kv, prior_gate, self.head_dim) + # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that + # only differ by a small amount, especially with large window widths). + new_pooled = self.kv_norm( + (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2) + ) + positions = ( + (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position) + .unsqueeze(0) + .expand(batch, -1) + ) + new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type) + else: + new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim)) + pooled_kv = new_pooled if cache_layer is None else cache_layer.update_indexer_pool(new_pooled) + + # --- Query side --- + cos_q, sin_q = self.rotary_emb(hidden_states, position_ids=position_ids, layer_type=self.rope_layer_type) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb(q, cos_q, sin_q).transpose(1, 2) + + # --- Score: ReLU(q·kᵀ) * weights, then top-k --- + scores = torch.matmul(q.float(), pooled_kv.transpose(-1, -2).float().unsqueeze(1)) # [B, S, H, T] + scores = F.relu(scores) * self.softmax_scale + weights = self.weights_proj(hidden_states).float() * self.weights_scaling # [B, S, H] + index_scores = (scores * weights.unsqueeze(-1)).sum(dim=2) # [B, S, T] + topk = min(self.index_topk, pooled_kv.shape[1]) + return index_scores.topk(topk, dim=-1).indices + + +class DeepseekV4HCACompressor(nn.Module): + """Heavily Compressed Attention compressor (paper §2.3.2, eqs. 20–23). 
Pools + every ``compress_rate_hca`` (m'=128) source tokens into a single compressed KV + entry with **non-overlapping** windows — no overlap state, no indexer. + + The three building blocks (paper notation in parentheses): + + * **kv** = ``kv_proj(hidden_states)`` — head-dim KV projection ``C ∈ R^{n×c}`` + (eq. 20). Doubles as both key and value (shared-KV MQA). + * **gate** = ``gate_proj(hidden_states)`` — head-dim compression weights + ``Z ∈ R^{n×c}`` (eq. 21). Combined with ``position_bias`` and softmaxed per + window to produce the convex combination that mixes ``compress_rate_hca`` + source KVs into one pooled entry. + * **pool** — running list of compressed KV entries (``C^{Comp}``, eq. 23). + Lives on :class:`DeepseekV4HCACache`; the in-flight buffer of tokens that + haven't yet filled a window lives there too. + + Each closed window of m' tokens produces one pooled entry: + ``C^{Comp}_i = Σ_{j∈window} softmax(Z_j + B)_j ⊙ C_j``. RoPE on the trailing + ``rope_head_dim`` slice is applied at the deterministic absolute position + ``i * compress_rate_hca + first_pool_position`` so cross-call concatenation stays + causality-correct. Returns the running pool ``[B, 1, T, head_dim]``. + + When ``past_key_values is None`` (a checkpoint replay zeroes the cache to break + the grad-cache loop), runs in stateless single-shot mode: pool every complete + window from ``hidden_states`` and discard the remainder. + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["heavily_compressed_attention"] + self.head_dim = config.head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV4RotaryEmbedding(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> torch.Tensor: + # ``q_residual`` / ``position_ids`` are unused — the uniform forward signature + # lets :class:`DeepseekV4Attention` call either compressor without branching. + batch, _, _ = hidden_states.shape + cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0 + else: + chunk_kv, chunk_gate, first_pool_position = cache_layer.update_compressor(kv, gate) + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1) + chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1) + self.position_bias.to( + chunk_gate.dtype + ) + # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that + # only differ by a small amount, especially with large window widths). 
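+            # One pooled entry per closed window i (paper eqs. 22–23), as computed below:
+            #   C_comp_i = kv_norm( Σ_{j in window i} softmax_j(Z_j + B) ⊙ C_j )
+            # with C = kv_proj(h), Z = gate_proj(h) and B = ``position_bias``.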
+ new_pooled = self.kv_norm( + (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2) + ) + positions = ( + (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position) + .unsqueeze(0) + .expand(batch, -1) + ) + new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type) + else: + new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim)) + if cache_layer is None: + return new_pooled.unsqueeze(1) + return cache_layer.update_compressor_pool(new_pooled).unsqueeze(1) + + +class DeepseekV4CSACompressor(nn.Module): + """Compressed Sparse Attention compressor (paper §2.3.1, eqs. 9–17). Pools every + ``compress_rate_csa`` (m=4) source tokens with **overlapping** windows — stride + ``compress_rate_csa`` and effective width ``2 * compress_rate_csa`` — and runs a + Lightning Indexer on top of the pool that scores queries with + ``∑_h w_{t,h} · ReLU(q_{t,h} · K^{IComp}_s)`` to gather the top ``index_topk`` + entries per query before they reach core attention. + + Compared to :class:`DeepseekV4HCACompressor` the differences are explicit: + + * ``kv_proj`` / ``gate_proj`` / ``position_bias`` project to **2 × head_dim** (the + learned channel split — high half pools into the current window, low half + pools into the next window's overlap with this one, see :func:`_overlap_pool`). + * The cache layer's ``compressor_overlap_*`` state carries the last full + window across forward calls. + * A :class:`DeepseekV4Indexer` sub-module gathers the top-``index_topk`` pool + entries per query (paper §2.3.1, "Lightning Indexer for Sparse Selection"). + """ + + rope_layer_type = "compress" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.compress_rate = config.compress_rates["compressed_sparse_attention"] + self.head_dim = config.head_dim + self.rope_head_dim = config.qk_rope_head_dim + # ``2 * head_dim`` because windows overlap: each pooled entry is a softmax-gated + # convex combination of ``compress_rate_csa`` *current* tokens (high-channel half) + # mixed with ``compress_rate_csa`` *previous* tokens (low-channel half). The + # learned channel split happens in :func:`_overlap_pool`. 
+ self.kv_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.gate_proj = nn.Linear(config.hidden_size, 2 * self.head_dim, bias=False) + self.position_bias = nn.Parameter(torch.empty(self.compress_rate, 2 * self.head_dim)) + self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV4RotaryEmbedding(config) + self.indexer = DeepseekV4Indexer(config) + + def forward( + self, + hidden_states: torch.Tensor, + q_residual: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache | None, + layer_idx: int, + ) -> torch.Tensor: + batch, seq_len, _ = hidden_states.shape + cache_layer = past_key_values.layers[layer_idx] if past_key_values is not None else None + kv = self.kv_proj(hidden_states) + gate = self.gate_proj(hidden_states) + if cache_layer is None: + usable = (kv.shape[1] // self.compress_rate) * self.compress_rate + chunk_kv, chunk_gate, first_pool_position = kv[:, :usable], gate[:, :usable], 0 + prior_kv, prior_gate = None, None + else: + chunk_kv, chunk_gate, first_pool_position = cache_layer.update_compressor(kv, gate) + prior_kv, prior_gate = cache_layer.get_compressor_overlap() + if chunk_kv.shape[1] > 0: + n_windows = chunk_kv.shape[1] // self.compress_rate + chunk_kv = chunk_kv.view(batch, n_windows, self.compress_rate, -1) + chunk_gate = chunk_gate.view(batch, n_windows, self.compress_rate, -1) + self.position_bias.to( + chunk_gate.dtype + ) + if cache_layer is not None: + # Persist the *raw* last full window (gate already biased) so the next + # forward call's first window can read its low-channel slice as prior. + cache_layer.set_compressor_overlap(chunk_kv[:, -1].clone(), chunk_gate[:, -1].clone()) + chunk_kv, chunk_gate = _overlap_pool(chunk_kv, chunk_gate, prior_kv, prior_gate, self.head_dim) + # Softmax in fp32 for stability (logits in bf16/fp16 can collapse pairs that + # only differ by a small amount, especially with large window widths). + new_pooled = self.kv_norm( + (chunk_kv * chunk_gate.softmax(dim=2, dtype=torch.float32).to(chunk_kv.dtype)).sum(dim=2) + ) + positions = ( + (torch.arange(n_windows, device=new_pooled.device) * self.compress_rate + first_pool_position) + .unsqueeze(0) + .expand(batch, -1) + ) + new_pooled = _rope_pool(new_pooled, self.rotary_emb, positions, self.rope_layer_type) + else: + new_pooled = chunk_kv.new_zeros((batch, 0, self.head_dim)) + pooled = ( + new_pooled.unsqueeze(1) + if cache_layer is None + else cache_layer.update_compressor_pool(new_pooled).unsqueeze(1) + ) + # Lightning Indexer: gather top-``index_topk`` pool entries per query. + topk = self.indexer(hidden_states, q_residual, position_ids, past_key_values, layer_idx) # [B, S, k] + expanded = pooled.unsqueeze(2).expand(-1, -1, seq_len, -1, -1) + idx = topk.unsqueeze(1).unsqueeze(-1).expand(-1, 1, -1, -1, self.head_dim) + return torch.gather(expanded, 3, idx).reshape(batch, 1, -1, self.head_dim) + + +COMPRESSOR_CLASSES = { + "sliding_attention": None, + "compressed_sparse_attention": DeepseekV4CSACompressor, + "heavily_compressed_attention": DeepseekV4HCACompressor, +} + + +# ----------------------------------------------------------------------------- +# Attention with sink. +# ----------------------------------------------------------------------------- + + +class DeepseekV4Attention(nn.Module): + """V4 attention block (paper §2.3). 
Single class for all three layer types — the + only thing that varies is the long-range branch (the ``compressor`` sub-module); + the surrounding QKV / RoPE / sink / sliding-window / output projection is + identical. The three layer types are dispatched by ``COMPRESSOR_CLASSES``: + + * ``sliding_attention``: ``compressor = None``; only the local sliding-window + K=V branch ("Full Attention"). + * ``compressed_sparse_attention``: :class:`DeepseekV4CSACompressor` — + low-compression overlapping-window pool plus a Lightning Indexer that keeps + the top-``index_topk`` pool entries per query (paper §2.3.1). + * ``heavily_compressed_attention``: :class:`DeepseekV4HCACompressor` — + high-compression non-overlapping-window pool, no indexer (paper §2.3.2). + + Block components (paper §2.3.3): + + * Shared-KV Multi-Query Attention: ``num_key_value_heads = 1``; ``kv_proj`` projects + directly to that single KV head and the same tensor is read as both key and + value. + * Partial RoPE on the trailing ``qk_rope_head_dim`` channels of each head ("Partial Rotary + Positional Embedding"). RoPE is also applied with position ``-i`` to the + attention output's rope slice, so the contribution of each KV entry stays a + function of the *relative* distance to the query. + * RMSNorm on the queries (``q_norm``) and the compressed KV head (``kv_norm``) + right before the core attention, to keep logits bounded. + * Per-head learnable attention sink (eq. 27). + * Grouped low-rank output projection (§2.3.1, "Grouped Output Projection"): + ``g`` head-groups → ``d_g``-dim intermediate outputs through a block-diagonal + :class:`DeepseekV4GroupedLinear`, then mixed back to ``hidden_size`` by ``o_b_proj``. + * A supplementary uncompressed sliding-window KV branch of size + ``sliding_window`` ("Additional Branch of Sliding Window Attention") that + preserves local fine-grained dependencies, concatenated with the + long-range compressor's output before core attention. + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.num_heads = config.num_attention_heads + self.num_key_value_groups = config.num_attention_heads # single KV head, broadcast to all + self.head_dim = config.head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.sliding_window = config.sliding_window + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.scaling = self.head_dim**-0.5 + + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False) + self.q_norm = DeepseekV4RMSNorm(config.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.head_dim, bias=False) + self.q_head_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + self.kv_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False) + self.kv_norm = DeepseekV4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.o_a_proj = DeepseekV4GroupedLinear( + self.num_heads * self.head_dim // config.o_groups, config.o_groups * config.o_lora_rank, config.o_groups + ) + self.o_b_proj = nn.Linear(config.o_groups * config.o_lora_rank, config.hidden_size, bias=False) + self.sinks = nn.Parameter(torch.empty(self.num_heads)) + # Long-range branch dispatched by ``layer_type`` (see ``COMPRESSOR_CLASSES`` + # above). ``None`` means full-attention / sliding-only — no compressor is + # built and the layer keeps just the local sliding-window K=V branch. 
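+        # Illustrative only — the real pattern comes from the checkpoint config, e.g.
+        #   config.layer_types = ["sliding_attention", "compressed_sparse_attention",
+        #                         "heavily_compressed_attention", ...]
+        # each entry picks one of the three long-range options above for that block.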
+ compressor_cls = COMPRESSOR_CLASSES[self.layer_type] + self.compressor = compressor_cls(config) if compressor_cls is not None else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + position_ids: torch.Tensor, + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch, seq_len = hidden_states.shape[:2] + cos, sin = position_embeddings + q_residual = self.q_norm(self.q_a_proj(hidden_states)) + q = self.q_b_proj(q_residual).view(batch, seq_len, -1, self.head_dim).transpose(1, 2) + q = self.q_head_norm(q) + kv = self.kv_norm(self.kv_proj(hidden_states)).view(batch, seq_len, 1, self.head_dim).transpose(1, 2) + q = apply_rotary_pos_emb(q, cos, sin) + kv = apply_rotary_pos_emb(kv, cos, sin) + # Under TP, ``q_b_proj`` is colwise-sharded so ``q`` carries + # ``num_heads / tp_size`` heads per rank, while ``kv`` (single shared head) is + # replicated. Refresh ``num_key_value_groups`` from the local ``q`` so the + # standard ``repeat_kv(key, num_key_value_groups)`` in the attention backends + # lifts the single kv head to match the rank-local query head count. + self.num_key_value_groups = q.shape[1] + + # --- Sliding-window K=V branch goes through the standard cache update --- + if past_key_values is not None: + kv, _ = past_key_values.update(kv, kv, self.layer_idx) + + if self.compressor is None: + full_kv = kv + else: + compressed_kv = self.compressor(hidden_states, q_residual, position_ids, past_key_values, self.layer_idx) + full_kv = torch.cat([kv, compressed_kv], dim=2) + + # Compressor concatenates extra long-range KV entries onto the sliding-window + # branch so ``full_kv`` is wider than the model-level mask was built for. Pad + # the mask's key dim with ``0.0`` (additive-mask convention: unmasked) so the + # compressor positions are attended. FA / SDPA backends consume the same 4D + # additive mask path here, so the pad is uniform across backends. + if attention_mask is not None and full_kv.shape[2] > attention_mask.shape[-1]: + attention_mask = F.pad(attention_mask, (0, full_kv.shape[2] - attention_mask.shape[-1]), value=0.0) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + attn_output, attn_weights = attention_interface( + self, + q, + full_kv, + full_kv, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, + **kwargs, + ) + + # De-rotate the output's rope slice. V4 shares K and V (``kv_proj`` projects to a + # single tensor), so V's rope slice carries the same per-token rotation as K. + # Attention sums V-rotated values across attended positions, so the output's + # rope slice is a position-mixed content; conjugate rotation at the query + # position pulls it back into a position-independent frame before the output + # projection mixes heads. 
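+        # In symbols (sketch): for a query at position p_q, the output's rope slice is
+        #   R(-p_q) · Σ_j a_j · R(p_j) v_j  =  Σ_j a_j · R(p_j - p_q) v_j,
+        # so every attended entry contributes only through its relative offset p_j - p_q.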
attn_output = apply_rotary_pos_emb(attn_output.transpose(1, 2), cos, -sin).transpose(1, 2) + + grouped = attn_output.reshape(batch, seq_len, -1).view(batch, seq_len, self.config.o_groups, -1) + return self.o_b_proj(self.o_a_proj(grouped).flatten(2)), attn_weights + + +class DeepseekV4HyperConnection(nn.Module): + r""" + Manifold-Constrained Hyper-Connections (mHC, Xie et al., 2026) strengthen the conventional + residual connections between adjacent Transformer blocks. + + Owns the learned (``fn``, ``base``, ``scale``) + parameters that turn the incoming ``hc_mult`` residual streams into collapse / expand + weights. The decoder layer instantiates two of these (one for the attention site, + one for the mlp site). + + ASCII shape guide — ``B`` = batch, ``S`` = seq, ``H`` = hc_mult, ``D`` = hidden_size:: + + hidden_streams flatten(2) RMSNorm-rescale + F.linear(fn) + [B, S, H, D] ──────────► [B, S, H*D] ─────────────────────────────────► + mix-logits + [B, S, (2+H)*H] + │ + ┌───────────────────────────────────────┴──────────────────────────────┐ + ▼ ▼ ▼ + pre logits post logits comb logits + [B, S, H] [B, S, H] [B, S, H, H] + × scale[0] × scale[1] × scale[2] + + base[:H] + base[H:2H] + base[2H:] + σ() + eps σ() + eps σ() + eps + │ │ │ + pre post Sinkhorn(iters) + (stream collapse weights) (block-output placement) row/col normalise + │ + comb + (stream mixer) + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + mix = (2 + self.hc_mult) * self.hc_mult + self.fn = nn.Parameter(torch.empty(mix, self.hc_mult * config.hidden_size)) + self.base = nn.Parameter(torch.empty(mix)) + # 3 = number of outputs from the mHC mapping: ``pre`` (input projection + # weights), ``post`` (sublayer output projection weights), ``comb`` (the + # H×H residual combine matrix that gets Sinkhorn-projected onto the + # doubly-stochastic manifold). Each output gets its own learned scale. + self.scale = nn.Parameter(torch.empty(3)) + + def forward(self, hidden_streams: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Turns the incoming streams into the ``(post, comb, collapsed)`` triplet for one sublayer site. + The raw ``comb`` logits are projected onto the manifold of doubly-stochastic matrices ``M`` with + the Sinkhorn-Knopp algorithm: positivity is guaranteed by the sigmoid + ``hc_eps`` applied to the + logits (the paper writes it as :math:`M^{(0)} = \exp(\tilde{B}_l)`), followed by + ``hc_sinkhorn_iters`` alternating row / column normalizations, + :math:`M^{(t)} = \mathcal{T}_r(\mathcal{T}_c(M^{(t-1)}))` (paper eq. 8), where + :math:`\mathcal{T}_r` and :math:`\mathcal{T}_c` denote row and column normalization, respectively. 
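+
+        A minimal standalone sketch of that projection (not calling this module)::
+
+            >>> import torch
+            >>> m = torch.sigmoid(torch.randn(4, 4)) + 1e-4  # positive, like the comb logits
+            >>> for _ in range(20):  # plays the role of hc_sinkhorn_iters
+            ...     m = m / m.sum(dim=-1, keepdim=True)  # row normalisation
+            ...     m = m / m.sum(dim=-2, keepdim=True)  # column normalisation
+            >>> # columns now sum to exactly 1 and rows to ~1: m is (near) doubly stochastic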
+ """ + flat = self.input_norm(hidden_streams.flatten(start_dim=2).float()) + mix = F.linear(flat, self.fn.float()) # [B, S, (2+H)*H] + pre_scale, post_scale, comb_scale = self.scale.unbind(0) + hc = self.hc_mult + pre = torch.sigmoid(mix[..., :hc] * pre_scale + self.base[:hc]) + self.hc_eps + post = torch.sigmoid(mix[..., hc : 2 * hc] * post_scale + self.base[hc : 2 * hc]) + self.hc_eps + comb = ( + torch.sigmoid( + mix[..., 2 * hc :].view(*mix.shape[:-1], hc, hc) * comb_scale + self.base[2 * hc :].view(hc, hc) + ) + + self.hc_eps + ) + for _ in range(self.hc_sinkhorn_iters): + comb = comb / (comb.sum(dim=-1, keepdim=True) + self.hc_eps) + comb = comb / (comb.sum(dim=-2, keepdim=True) + self.hc_eps) + # Collapse the ``hc_mult`` parallel streams down to a single sequence using + # the ``pre`` weights (Manifold-Constrained input projection): one weighted + # sum across the stream axis, ready for the sublayer (attn / MLP). + collapsed = (pre.unsqueeze(-1) * hidden_streams).sum(dim=2).to(hidden_streams.dtype) + return post, comb, collapsed + + +class DeepseekV4HyperHead(nn.Module): + """Final HC-stream collapse; used by ``DeepseekV4Model`` before the shared RMSNorm.""" + + def __init__(self, config: DeepseekV4Config): + super().__init__() + self.hc_mult = config.hc_mult + self.input_norm = DeepseekV4UnweightedRMSNorm(eps=config.rms_norm_eps) + self.eps = config.hc_eps + self.hc_fn = nn.Parameter(torch.empty(self.hc_mult, self.hc_mult * config.hidden_size)) + self.hc_base = nn.Parameter(torch.empty(self.hc_mult)) + self.hc_scale = nn.Parameter(torch.empty(1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + flat = self.input_norm(x.flatten(2).float()) + mixes = F.linear(flat, self.hc_fn.float()) + pre = torch.sigmoid(mixes * self.hc_scale.float() + self.hc_base.float()) + self.eps + return (pre.unsqueeze(-1) * x).sum(dim=2).to(x.dtype) + + +class DeepseekV4MLP(LlamaMLP): + pass + + +@use_experts_implementation +class DeepseekV4Experts(MixtralExperts): + """Routed experts (paper §2.1). Inherits the Mixtral layout (no biases, + ``[num_experts, 2 * intermediate, hidden]`` ``gate_up_proj``) and the per-expert + iteration loop; the only V4-specific bit is the ``swiglu_limit`` clamp on the + gate / up halves before the SiLU mix. + + ``config.intermediate_size`` resolves to ``moe_intermediate_size`` via the + config's ``attribute_map``, so ``MixtralExperts.__init__`` builds the right + Linear shapes without an override. 
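+
+    A minimal numeric sketch of the clamp on one expert's split halves (illustrative values
+    only, assuming ``swiglu_limit=10``)::
+
+        gate = torch.tensor([12.0, -3.0]).clamp(max=10.0)            # -> [10.0, -3.0]
+        up = torch.tensor([15.0, -15.0]).clamp(min=-10.0, max=10.0)  # -> [10.0, -10.0]
+        out = torch.nn.functional.silu(gate) * up                    # ~ [100.0, 1.4]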
+ """ + + def __init__(self, config: DeepseekV4Config): + super().__init__(config) + self.limit = config.swiglu_limit + + def forward( + self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + ) -> torch.Tensor: + final = torch.zeros_like(hidden_states) + with torch.no_grad(): + mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + hit = torch.greater(mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(mask[expert_idx]) + gate, up = F.linear(hidden_states[token_idx], self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + current = self.act_fn(gate) * up + current = F.linear(current, self.down_proj[expert_idx]) * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final + + +class DeepseekV4TopKRouter(MixtralTopKRouter): + """DeepSeekMoE top-k router (paper §2.1, "Mixture-of-Experts"). Two changes from + the V3 router: + + * The expert affinity activation is ``Sqrt(Softplus(·))`` instead of the V3 + Sigmoid (paper §2.1: "we change the activation function that computes the + affinity scores from Sigmoid(·) into Sqrt(Softplus(·))"). The ``scoring_func`` + config field selects this for V4 checkpoints. + * The constraint on the number of routing target nodes used in V3 is dropped, + and the V3 ``n_group`` / ``topk_group`` machinery is removed entirely (paper + §2.1: "we remove the constraint on the number of routing target nodes"). + + The auxiliary-loss-free strategy (DeepSeek's ``noaux_tc``) is preserved via the + per-expert ``e_score_correction_bias`` buffer that biases the top-k argmax + without flowing gradients. + """ + + def __init__(self, config: DeepseekV4Config): + super().__init__(config) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("e_score_correction_bias", torch.zeros(self.num_experts), persistent=True) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat.float(), self.weight.float()) + scores = self.score_fn(logits) + indices = torch.topk(scores + self.e_score_correction_bias, self.top_k, dim=-1, sorted=False).indices + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4HashRouter(MixtralTopKRouter): + """Hash routing for the first ``mlp_layer_types == "hash_moe"`` MoE layers (paper + §2.1). Expert selection is determined by a fixed ``tid2eid[input_ids]`` lookup — + a frozen token-id → expert-id table — instead of a learned argmax. The learned + gate ``weight`` still produces the per-expert scores that weight the selected + experts' activations; only the *which-experts* selection is static. 
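+
+    A minimal sketch of the lookup with a made-up table (``vocab_size=3``, ``top_k=2``; the
+    real table is loaded from the checkpoint)::
+
+        tid2eid = torch.tensor([[0, 2], [1, 3], [0, 1]])  # one row of expert ids per token id
+        input_ids = torch.tensor([[2, 0]])                # [batch=1, seq=2]
+        indices = tid2eid[input_ids.reshape(-1)]          # [[0, 1], [0, 2]], fixed, never learned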
+ """ + + def __init__(self, config: DeepseekV4Config): + super().__init__(config) + self.score_fn = ACT2FN[config.scoring_func] + self.routed_scaling_factor = config.routed_scaling_factor + self.register_buffer("tid2eid", torch.zeros(config.vocab_size, self.top_k, dtype=torch.long), persistent=True) + + def forward( + self, hidden_states: torch.Tensor, input_ids: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + flat = hidden_states.reshape(-1, self.hidden_dim) + logits = F.linear(flat.float(), self.weight.float()) + scores = self.score_fn(logits) + indices = self.tid2eid[input_ids.reshape(-1)].long() + weights = scores.gather(1, indices) + weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20) + return logits, weights * self.routed_scaling_factor, indices + + +class DeepseekV4SparseMoeBlock(nn.Module): + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.is_hash = config.mlp_layer_types[layer_idx] == "hash_moe" + self.gate = DeepseekV4HashRouter(config) if self.is_hash else DeepseekV4TopKRouter(config) + self.experts = DeepseekV4Experts(config) + self.shared_experts = DeepseekV4MLP(config) + + def forward(self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None) -> torch.Tensor: + batch, seq_len, hidden_dim = hidden_states.shape + residual = hidden_states + flat = hidden_states.view(-1, hidden_dim) + if self.is_hash: + if input_ids is None: + raise ValueError( + "DeepseekV4's hash-routing layers need `input_ids` to look up expert indices. " + "The `inputs_embeds`-only inference path is not supported for models with " + "any `hash_moe` entries in `mlp_layer_types`." + ) + _, weights, indices = self.gate(hidden_states, input_ids) + else: + _, weights, indices = self.gate(hidden_states) + routed = self.experts(flat, indices, weights).view(batch, seq_len, hidden_dim) + return routed + self.shared_experts(residual) + + +class DeepseekV4DecoderLayer(GradientCheckpointingLayer): + r"""DeepSeek-V4 decoder block (paper §2). Differs from a classic residual block in + two places: + + * The residual is a stack of ``hc_mult`` parallel streams kept in shape + ``[B, S, hc_mult, D]`` throughout the block, mixed in and out via two + :class:`DeepseekV4HyperConnection` modules (Manifold-Constrained Hyper- + Connections / mHC, paper §2.2; Xie et al., 2026). The mHC mappings constrain + the residual transform to the manifold of doubly-stochastic matrices via the + Sinkhorn-Knopp projection — making signal propagation non-expansive across + deep stacks. + * ``self_attn`` is :class:`DeepseekV4Attention` for every layer. Its compressor + sub-module is the only thing that varies by layer type + (:class:`DeepseekV4HCACompressor` for HCA layers, + :class:`DeepseekV4CSACompressor` for CSA, picked via + ``config.layer_types[layer_idx]``); the CSA compressor also owns the + Lightning Indexer at ``self_attn.compressor.indexer``. 
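+
+    At either site, the "expand" step that follows the sublayer writes back, per stream ``h``
+    (a sketch of the tensor ops in ``forward`` below, not an equation from the paper)::
+
+        new_streams[h] = post[h] * sublayer_out + sum_j(comb[h, j] * streams[j])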
+ + Classic residual decoder layer:: + + h ──► norm ──► self_attn ──► + ──► norm ──► mlp ──► + + └──────── residual ────────┘ └─────── residual ───┘ + + Deepseek V4 decoder layer (``H = hc_mult`` parallel residual streams throughout):: + + attention site mlp site + ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ + │ hidden_streams [B, S, H, D] │ │ hidden_streams [B, S, H, D] │ + │ │ │ │ │ │ + │ attn_hc(streams) ─► (pre, post, comb) │ │ ffn_hc(streams) ─► (pre, post, comb) │ + │ │ │ │ │ │ + │ Σ pre·streams (collapse) │ │ Σ pre·streams (collapse) │ + │ │ │ │ │ │ + │ input_layernorm │ │ post_attention_layernorm │ + │ │ │ │ │ │ + │ self_attn │ │ mlp (MoE routed + shared) │ + │ │ │ │ │ │ + │ post·output + comb·streams (expand) │ │ post·output + comb·streams (expand) │ + │ │ │ │ │ │ + │ ▼ │ │ ▼ │ + │ new hidden_streams ──────────────────┘ │ new hidden_streams │ + └────────────────────────────────────────┘ └────────────────────────────────────────┘ + + + """ + + def __init__(self, config: DeepseekV4Config, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + self.self_attn = DeepseekV4Attention(config, layer_idx) + self.mlp = DeepseekV4SparseMoeBlock(config, layer_idx) + self.input_layernorm = DeepseekV4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn_hc = DeepseekV4HyperConnection(config) + self.ffn_hc = DeepseekV4HyperConnection(config) + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + # hidden_states throughout: [B, S, hc_mult, hidden]. + + # --- Attention site: collapse → norm → attn → expand --- + post, comb, collapsed = self.attn_hc(hidden_states) + attn_output, _ = self.self_attn(self.input_layernorm(collapsed), **kwargs) + dtype = hidden_states.dtype + hidden_states = post.to(dtype).unsqueeze(-1) * attn_output.unsqueeze(-2) + torch.matmul( + comb.to(dtype), hidden_states + ) + + # --- MLP site: collapse → norm → mlp → expand --- + post, comb, collapsed = self.ffn_hc(hidden_states) + mlp_output = self.mlp(self.post_attention_layernorm(collapsed), input_ids=input_ids) + dtype = hidden_states.dtype + return post.to(dtype).unsqueeze(-1) * mlp_output.unsqueeze(-2) + torch.matmul(comb.to(dtype), hidden_states) + + +# ----------------------------------------------------------------------------- +# Pre-trained base + Model + ForCausalLM. +# ----------------------------------------------------------------------------- + + +class DeepseekV4PreTrainedModel(MixtralPreTrainedModel): + config_class = DeepseekV4Config + base_model_prefix = "model" + _no_split_modules = ["DeepseekV4DecoderLayer"] + # V4 ships eager-only: the compressor / indexer paths weren't validated against + # SDPA / FlashAttention / FlexAttention kernels — leaving these ``False`` makes + # ``set_attn_implementation`` reject those backends instead of silently routing + # through them. + _supports_flash_attn = False + _supports_sdpa = False + _supports_flex_attn = False + # The compressor's rolling-window buffer / pool / overlap state lives on the + # per-layer cache (:class:`DeepseekV4HCACache` / :class:`DeepseekV4CSACache`) + # and isn't compatible with :class:`StaticCache` — that path would hand the + # compressor a :class:`StaticSlidingWindowLayer` with no ``update_compressor`` + # method. 
Disabling fullgraph compile keeps generation tests on the dynamic + # cache build that does dispatch to V4's own cache layers. + _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = ["attn_hc", "ffn_hc"] + _keys_to_ignore_on_load_unexpected = [r"model\.mtp\..*"] + # ``_is_stateful`` opts out of generation modes that need to roll the cache + # back across drafts (assisted generation, prompt lookup, contrastive search). + # The compressor's running-window state isn't rewindable, so ``generate`` + # raises a clear error early instead of failing deep in the compressor with + # a missing-method ``AttributeError``. + _is_stateful = True + _can_record_outputs = { + "router_logits": OutputRecorder(DeepseekV4TopKRouter, index=0), + "hidden_states": DeepseekV4DecoderLayer, + "attentions": DeepseekV4Attention, + } + + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + std = self.config.initializer_range + if isinstance(module, (DeepseekV4TopKRouter, DeepseekV4HashRouter)): + init.normal_(module.weight, mean=0.0, std=std) + if isinstance(module, DeepseekV4TopKRouter): + init.zeros_(module.e_score_correction_bias) # buffer + if isinstance(module, DeepseekV4HashRouter): + init.zeros_(module.tid2eid) # buffer; real values come from the checkpoint + elif isinstance(module, DeepseekV4Experts): + init.normal_(module.gate_up_proj, mean=0.0, std=std) + init.normal_(module.down_proj, mean=0.0, std=std) + elif isinstance(module, DeepseekV4Attention): + init.zeros_(module.sinks) + elif isinstance(module, DeepseekV4HyperConnection): + init.normal_(module.fn, mean=0.0, std=std) + init.zeros_(module.base) + init.ones_(module.scale) + elif isinstance(module, DeepseekV4HyperHead): + init.normal_(module.hc_fn, mean=0.0, std=std) + init.zeros_(module.hc_base) + init.ones_(module.hc_scale) + elif isinstance(module, (DeepseekV4HCACompressor, DeepseekV4CSACompressor, DeepseekV4Indexer)): + init.zeros_(module.position_bias) + elif isinstance(module, DeepseekV4RotaryEmbedding): + for layer_type in module.layer_types: + rope_init_fn = module.compute_default_rope_parameters + if module.rope_type[layer_type] != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]] + curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type) + init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq) + init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq) + + +@auto_docstring +class DeepseekV4Model(LlamaModel): + def __init__(self, config: DeepseekV4Config): + super().__init__(config) + self.layers = nn.ModuleList( + [DeepseekV4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.rotary_emb = DeepseekV4RotaryEmbedding(config) + self.hc_head = DeepseekV4HyperHead(config) + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + return_cache = past_key_values if use_cache else None + if past_key_values is None: + past_key_values = DynamicCache(config=self.config) + 
if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: + past_seen = past_key_values.get_seq_length() + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen + position_ids = position_ids.unsqueeze(0) + # ``generate()`` may pass a per-layer-type mask dict already built by + # ``create_masks_for_generate``; all V4 layer types use the same sliding-window + # mask, so use the prebuilt one directly. Otherwise build it here. + if isinstance(attention_mask, dict): + causal_mask = next(iter(attention_mask.values())) + else: + causal_mask = create_sliding_window_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous() + position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main") + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + position_ids=position_ids, + attention_mask=causal_mask, + input_ids=input_ids, + past_key_values=past_key_values, + **kwargs, + ) + + hidden_states = self.norm(self.hc_head(hidden_states)) + return MoeModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=return_cache) + + +class DeepseekV4ForCausalLM(MixtralForCausalLM): + pass + + +__all__ = [ + "DeepseekV4PreTrainedModel", + "DeepseekV4Model", + "DeepseekV4ForCausalLM", +] diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 5390a9c3e8d3..fbebea2977be 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -294,6 +294,19 @@ def get_quantize_ops(self): def get_weight_conversions(self): return [] + def update_weight_conversions(self, weight_conversions): + """Give the quantizer a chance to rewrite the weight conversion pipeline. + + Loading runs ``renamings → converters → (dequant → merge → concat)``. Dequant + has to happen *before* any merge/concat op because those operations aren't + aware of per-block scales, so the per-expert (weight, scale) pairs need to be + collapsed into full-precision tensors first. Subclasses (e.g. the FP8 + quantizer in ``dequantize=True`` mode) override this to inject a dequantize + op at the start of each model-provided :class:`WeightConverter` and attach the + matching scale source patterns. Default: no-op. + """ + return weight_conversions + self.get_weight_conversions() + class SequentialLlama4TextExperts(ModuleList): """ diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index e736b0f21915..be10624d4842 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -166,3 +166,61 @@ def get_weight_conversions(self): ) ] return [] + + def update_weight_conversions(self, weight_conversions): + """When loading with ``dequantize=True``, attach an :class:`Fp8Dequantize` op to + every existing :class:`WeightConverter` so that per-block scales are folded into + the weight *before* any later merge/concat ops collapse the per-expert structure. + + For each model-supplied converter that has a ``.weight`` source, we: + 1. 
anchor the existing weight patterns with ``$`` so they don't accidentally + also match the ``.weight_scale_inv`` keys (the regex is searched, so the + unanchored prefix would match both, sending scales to the wrong bucket); + 2. add anchored ``*.weight_scale_inv`` sources next to each weight pattern so + the loader collects scale tensors alongside the weight tensors into the + *same* converter bucket (both keys rewrite to the same target); + 3. prepend a fresh :class:`Fp8Dequantize` op so dequant runs first, before + any merge/concat collapses the per-expert structure. + + The generic ``weight$ + weight_scale_inv → weight`` converter from + :meth:`get_weight_conversions` is still appended at the end as a fallback for + plain ``nn.Linear`` weights with no model-specific converter. + """ + if not (self.pre_quantized and self.quantization_config.dequantize): + return weight_conversions + self.get_weight_conversions() + + from ..core_model_loading import WeightConverter, WeightRenaming + from ..integrations.finegrained_fp8 import Fp8Dequantize + + # Some upstream FP8 checkpoints (e.g. DeepSeek-V4-Flash) ship per-block scales + # under a ``.scale`` suffix instead of HF's canonical ``.weight_scale_inv``. + # Prepending the rename here (instead of in each model's conversion_mapping) + # keeps the model-side mapping clean — the rename only kicks in when FP8 dequant + # is actually active, so a non-FP8 save / load round-trip doesn't see a stray + # rule that ``test_reverse_loading_mapping`` can't match. + scale_rename = WeightRenaming(source_patterns=r"^(.+)\.scale$", target_patterns=r"\1.weight_scale_inv") + weight_conversions = [scale_rename] + list(weight_conversions) + + updated: list = [] + for conv in weight_conversions: + # Only WeightConverter has ``.operations`` to extend with the dequant op; + # WeightRenaming (e.g. the ``scale_rename`` we prepended) just passes through. + if not isinstance(conv, WeightConverter): + updated.append(conv) + continue + weight_sources = [p for p in conv.source_patterns if p.endswith(".weight")] + if weight_sources: + anchored_weight = [p + "$" for p in weight_sources] + scale_sources = [p[: -len(".weight")] + ".weight_scale_inv$" for p in weight_sources] + other = [p for p in conv.source_patterns if not p.endswith(".weight")] + new_sources = anchored_weight + scale_sources + other + new_ops = [Fp8Dequantize(self)] + list(conv.operations) + conv = WeightConverter( + source_patterns=new_sources, + target_patterns=conv._original_target_patterns, + operations=new_ops, + ) + updated.append(conv) + # Generic fallback for plain ``nn.Linear`` weights with no model-specific converter. 
+ updated.extend(self.get_weight_conversions()) + return updated diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index bf085d87498c..a6c1f5334516 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1716,7 +1716,7 @@ def post_init(self): raise ValueError("weight_block_size must be a tuple of two positive integers") def get_loading_attributes(self): - return {"dequantize": self.dequantize} + return {"dequantize": self.dequantize, "modules_to_not_convert": self.modules_to_not_convert} class QuarkConfig(QuantizationConfigMixin): diff --git a/tests/models/deepseek_v4/__init__.py b/tests/models/deepseek_v4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/deepseek_v4/test_modeling_deepseek_v4.py b/tests/models/deepseek_v4/test_modeling_deepseek_v4.py new file mode 100644 index 000000000000..7d6172296d0c --- /dev/null +++ b/tests/models/deepseek_v4/test_modeling_deepseek_v4.py @@ -0,0 +1,421 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +import unittest + +from parameterized import parameterized + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device + + +if is_torch_available(): + import torch + + from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + DeepseekV4Config, + DeepseekV4ForCausalLM, + DeepseekV4Model, + DynamicCache, + FineGrainedFP8Config, + ) + from transformers.models.deepseek_v4.modeling_deepseek_v4 import DeepseekV4HCACompressor + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class DeepseekV4ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = DeepseekV4Model + + def __init__(self, parent, **kwargs): + # ``CausalLMModelTester.__init__`` assigns a fixed set of attributes from its + # keyword defaults (``hidden_size``, ``num_attention_heads`` and friends); those + # overwrite any class-level attributes of the same name. Pass V4 defaults through + # ``kwargs`` so the tester instance reflects V4's shape. + kwargs.setdefault("hidden_size", 64) + kwargs.setdefault("num_attention_heads", 4) + kwargs.setdefault("num_key_value_heads", 1) + kwargs.setdefault("num_hidden_layers", 2) + kwargs.setdefault("num_experts_per_tok", 2) + kwargs.setdefault("moe_intermediate_size", 64) + kwargs.setdefault("max_position_embeddings", 64) + super().__init__(parent, **kwargs) + # V4-only attributes that ``CausalLMModelTester.get_config`` will pull by name. + self.head_dim = 32 + self.partial_rotary_factor = 8 / 32 # qk_rope_head_dim=8 / head_dim=32 + self.q_lora_rank = 32 + self.o_groups = 2 + self.o_lora_rank = 16 + self.n_routed_experts = 4 + self.n_shared_experts = 1 + # All ``"moe"`` (no ``"hash_moe"``) so the ``inputs_embeds``-only generation + # tests in ``CausalLMModelTest`` can exercise the model without running into + # the hash router's ``input_ids`` requirement. A dedicated test covers the + # hash path. 
+ self.mlp_layer_types = ["moe", "moe"] + self.layer_types = ["heavily_compressed_attention", "compressed_sparse_attention"] + self.sliding_window = 8 + self.hc_mult = 2 + self.hc_sinkhorn_iters = 3 + self.hc_eps = 1.0e-6 + self.index_n_heads = 2 + self.index_head_dim = 16 + self.index_topk = 2 + self.num_nextn_predict_layers = 0 + self.scoring_func = "sqrtsoftplus" + self.routed_scaling_factor = 1.5 + self.swiglu_limit = 10.0 + self.rope_theta = 10000.0 + self.compress_rope_theta = 160000.0 + self.attention_bias = False + self.attention_dropout = 0.0 + + +@require_torch +class DeepseekV4ModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = DeepseekV4ModelTester + + # Indexer parameters only influence the argmax over compressed positions (``topk``), + # which is non-differentiable — their gradients flow through a separate objective in + # the upstream training recipe, not the main causal-LM loss. + test_all_params_have_gradient = False + + # No SequenceClassification / TokenClassification / QA heads on V4. + def is_pipeline_test_to_skip(self, *args, **kwargs): + return True + + def _check_attentions_for_generate( + self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values + ): + # V4 layers with a Compressor attend to extra pooled positions, so the KV + # length varies per layer. We only check the shape invariants: batched, same + # number-of-heads and query-length; the KV-length axis may differ across layers. + import torch # noqa: PLC0415 + + self.assertIsInstance(attentions, tuple) + self.assertEqual(len(attentions), (output_length - prompt_length)) + for _, iter_attentions in enumerate(attentions): + self.assertIsInstance(iter_attentions, tuple) + for layer_attention in iter_attentions: + self.assertIsInstance(layer_attention, torch.Tensor) + self.assertEqual(layer_attention.shape[0], batch_size) + self.assertEqual(layer_attention.shape[1], config.num_attention_heads) + + @unittest.skip( + "V4's rotary uses per-layer-type inv_freq buffers (Gemma3 pattern); the common test calls forward without `layer_type` and reads `.inv_freq`, neither of which apply." + ) + def test_model_rope_scaling_frequencies(self): + pass + + @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + @unittest.skip( + "V4's rotary uses per-layer-type rope_parameters; the common test sets a flat dict and skips for multi-layer-type rotaries." + ) + def test_model_rope_scaling_from_config(self, scaling_type): + pass + + def test_hidden_states_output(self): + # V4 layers emit a 4D ``[B, S, hc_mult, hidden]`` tensor — the hc_mult streams + # are only collapsed at the top of the model via ``hc_head``. The common + # ``test_hidden_states_output`` assumes ``(batch, seq, hidden)``; we re-run the + # same check but accept the extra HC axis, and we additionally assert the final + # (post-hc_head) ``last_hidden_state`` has the standard 3D shape. 
+ import torch # noqa: PLC0415 + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device).eval() + with torch.no_grad(): + outputs = model(**inputs_dict) + hidden_states = outputs.hidden_states if hasattr(outputs, "hidden_states") else outputs[-1] + self.assertIsNotNone(hidden_states) + self.assertEqual(len(hidden_states), config.num_hidden_layers + 1) + seq_len = inputs_dict["input_ids"].shape[1] + for layer_h in hidden_states: + # Accept either the collapsed (3D) post-head shape or the per-layer 4D shape. + if layer_h.ndim == 3: + self.assertEqual(layer_h.shape, (inputs_dict["input_ids"].shape[0], seq_len, config.hidden_size)) + else: + self.assertEqual( + layer_h.shape, + (inputs_dict["input_ids"].shape[0], seq_len, config.hc_mult, config.hidden_size), + ) + + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): + # Every V4 layer is sliding-window, so the cache is length-bounded to + # ``sliding_window`` instead of the full ``seq_length`` the parent tester expects. + # We also accept the compressed-segment positions that ``DeepseekV4Attention`` + # appends on compress layers (they live beyond the window on the keys axis). + import torch # noqa: PLC0415 + + num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = config.head_dim + for layer in past_key_values.layers: + keys, values = layer.keys, layer.values + self.assertIsInstance(keys, torch.Tensor) + self.assertEqual(keys.shape[0], batch_size) + self.assertEqual(keys.shape[1], num_kv_heads) + self.assertEqual(keys.shape[3], head_dim) + self.assertEqual(keys.shape, values.shape) + + @unittest.skip( + reason=( + "V4's conversion mapping is two-pass: a structural prefix rename " + "(``layers.X.attn.`` → ``model.layers.X.self_attn.``) runs first, then specific in-prefix " + "renames operate on the already-prefixed HF-form keys (``model.layers.X.self_attn.compressor.norm.`` " + "→ ``...compressor.kv_norm.``). This split is load-bearing for save / load round-tripping — " + "any single-pass ordering loses information in either direction (the general prefix rule " + "and a specific in-prefix rule both want to match the same upstream key, and one of the " + "two directions ends up with the general rule stealing the match). The base " + "``test_reverse_loading_mapping`` checks every source pattern against the *upstream-form* " + "serialized keys, so the Pass 2 patterns (written in HF form) inherently can't satisfy " + "that invariant. The actual round-trip is exercised by ``test_save_load``." + ) + ) + def test_reverse_loading_mapping(self): + pass + + @unittest.skip( + reason=( + "V4's compressor pools windows of ``compress_rate`` consecutive tokens *before* the " + "attention mask is applied — left-padding shifts the window boundaries so pad tokens " + "get folded into the pooled KV entries, and the resulting logits diverge from the " + "unpadded run by design (same fundamental limitation as RecurrentGemma)." + ) + ) + def test_left_padding_compatibility(self): + pass + + def _check_hidden_states_for_generate( + self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False + ): + # V4's per-layer hidden states carry an extra ``hc_mult`` dim (Hyper-Connection + # parallel streams). 
We skip the exact seq-length assertion the base tester does, + # because assisted-decoding feeds arbitrary draft-token batches in, and just + # sanity-check batch / hidden dims. + import torch # noqa: PLC0415 + + self.assertIsInstance(hidden_states, tuple) + self.assertEqual(len(hidden_states), (output_length - prompt_length)) + for iter_hidden_states in hidden_states: + self.assertIsInstance(iter_hidden_states, tuple) + for layer_hidden in iter_hidden_states: + self.assertIsInstance(layer_hidden, torch.Tensor) + self.assertEqual(layer_hidden.shape[0], batch_size) + self.assertEqual(layer_hidden.shape[-1], config.hidden_size) + + +def _tiny_config(**overrides): + """Smallest V4 config that still exercises every architectural piece: HC streams + (``hc_mult=2``), hash routing (layer 0), a local-SWA layer, a compressor-with- + indexer layer (ratio 4), and a routed MoE with a shared expert. + """ + defaults = { + "vocab_size": 32, + "hidden_size": 32, + "head_dim": 16, + "partial_rotary_factor": 4 / 16, # qk_rope_head_dim=4 / head_dim=16 + "q_lora_rank": 16, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "layer_types": ["heavily_compressed_attention", "compressed_sparse_attention"], + "sliding_window": 4, + "hc_mult": 2, + "hc_sinkhorn_iters": 3, + "hc_eps": 1e-6, + "moe_intermediate_size": 32, + "n_routed_experts": 4, + "n_shared_experts": 1, + "num_experts_per_tok": 2, + "mlp_layer_types": ["hash_moe", "moe"], + "scoring_func": "sqrtsoftplus", + "routed_scaling_factor": 1.0, + "swiglu_limit": 10.0, + "o_groups": 2, + "o_lora_rank": 8, + "index_n_heads": 2, + "index_head_dim": 8, + "index_topk": 2, + "num_nextn_predict_layers": 0, + "max_position_embeddings": 32, + "rope_theta": 10000.0, + "compress_rope_theta": 10000.0, # match main rope for a cleaner parity check + "attention_bias": False, + "attention_dropout": 0.0, + } + defaults.update(overrides) + return DeepseekV4Config(**defaults) + + +@require_torch +class DeepseekV4ParityTest(unittest.TestCase): + """Functional sanity checks against tiny-config reference implementations of the + V4-specific pieces (compressor pooling, HC mix + collapse). These re-derive the + math from the upstream ``inference/model.py`` and compare to our HF modules, so a + regression in the packed cache / HC / pool code would surface here numerically. + """ + + def test_compressor_pool_matches_reference(self): + """Re-implement the reference ``Compressor._pool`` (softmax-gated sum-pool with + a learned ``position_bias``) and check it matches what the V4 + :class:`DeepseekV4HCACache` + :class:`DeepseekV4HCACompressor` produce inline. + """ + torch.manual_seed(0) + batch, length, head_dim, rate = 2, 8, 16, 4 + kv = torch.randn(batch, length, head_dim) + gate = torch.randn(batch, length, head_dim) + position_bias = torch.randn(rate, head_dim) + + # Reproduce the V4 in-line pool from ``DeepseekV4HCACompressor._pool``. + n_windows = length // rate + view_kv = kv.view(batch, n_windows, rate, head_dim) + view_gate = gate.view(batch, n_windows, rate, head_dim) + position_bias.to(gate.dtype) + ours = (view_kv * view_gate.softmax(dim=2)).sum(dim=2) + + # Reference (transcribed from upstream ``inference/model.py``). 
+ reference = torch.zeros(batch, n_windows, head_dim) + for b in range(batch): + for i in range(n_windows): + window_kv = kv[b, i * rate : (i + 1) * rate] + window_gate = gate[b, i * rate : (i + 1) * rate] + position_bias + w = torch.softmax(window_gate, dim=0) + reference[b, i] = (window_kv * w).sum(dim=0) + + torch.testing.assert_close(ours, reference, rtol=1e-5, atol=1e-6) + + def test_compressor_cache_accumulates_across_calls(self): + """Feeding the HCA compressor one token at a time must produce the same pool + as feeding the whole sequence. Using HCA keeps the test indexer-free. + """ + torch.manual_seed(1) + config = _tiny_config( + layer_types=["heavily_compressed_attention", "heavily_compressed_attention"], + sliding_window=128, + max_position_embeddings=512, + compress_rates={"compressed_sparse_attention": 4, "heavily_compressed_attention": 128}, + ) + compressor = DeepseekV4HCACompressor(config).eval() + # Initialise ``position_bias`` to non-zero so the test exercises the pooling math. + torch.nn.init.normal_(compressor.position_bias, std=0.1) + + batch, seq_len = 1, 256 # two full windows + hidden_states = torch.randn(batch, seq_len, config.hidden_size) + position_ids = torch.arange(seq_len).unsqueeze(0) + + cache_full = DynamicCache(config=config) + with torch.no_grad(): + one_shot = compressor(hidden_states, None, position_ids, cache_full, 1) + + cache_inc = DynamicCache(config=config) + with torch.no_grad(): + for step in range(seq_len): + incremental = compressor(hidden_states[:, step : step + 1], None, torch.tensor([[step]]), cache_inc, 1) + self.assertEqual(one_shot.shape, incremental.shape) + torch.testing.assert_close(one_shot, incremental, rtol=1e-4, atol=1e-5) + + def test_tiny_forward_is_deterministic_and_finite(self): + """End-to-end smoke: tiny ``DeepseekV4ForCausalLM`` forward produces finite + logits of the right shape, and is deterministic under the same seed.""" + torch.manual_seed(42) + config = _tiny_config() + model = DeepseekV4ForCausalLM(config).eval() + + torch.manual_seed(0) + input_ids = torch.randint(0, config.vocab_size, (2, 10)) + with torch.no_grad(): + out_a = model(input_ids).logits + out_b = model(input_ids).logits + + self.assertEqual(out_a.shape, (2, 10, config.vocab_size)) + self.assertTrue(torch.isfinite(out_a).all()) + torch.testing.assert_close(out_a, out_b) # deterministic + + def test_tiny_generate_runs(self): + """Greedy-generate 4 new tokens on top of a 6-token prompt and check we get 10 + tokens out. Exercises the full generation loop: cache adopt, window cache, + compressor state, HC, indexer gather.""" + torch.manual_seed(42) + config = _tiny_config() + model = DeepseekV4ForCausalLM(config).eval() + + torch.manual_seed(0) + input_ids = torch.randint(0, config.vocab_size, (1, 6)) + # ``eos_token_id=-1`` keeps the freshly initialised random model from EOS-stopping + # before max_new_tokens, so the shape assertion is deterministic. + with torch.no_grad(): + out = model.generate(input_ids, max_new_tokens=4, do_sample=False, eos_token_id=-1) + self.assertEqual(out.shape, (1, 10)) + self.assertTrue(torch.isfinite(out.float()).all()) + + +@require_torch +@require_torch_accelerator +@slow +class DeepseekV4IntegrationTest(unittest.TestCase): + """End-to-end check on the published DeepSeek-V4-Flash checkpoint. + + Loads the real 43-layer FP8 weights, dequantizes on the fly via + :class:`FineGrainedFP8Config`, and greedy-generates a continuation of a fixed + prompt. 
The forward path that this test covers is everything past what the typical
+    tiny-config tests can reach: the per-layer FP8 dequant in
+    ``update_weight_conversions``, the ``compress_ratios → layer_types`` config
+    translation (sliding / CSA / HCA), the ``coff=2`` overlap-window pooling on CSA
+    layers and the indexer's inner pool, the per-head Q rescale in
+    :class:`DeepseekV4Attention`, the YaRN-blended ``compress_rope_theta`` in the
+    compressor, the trailing-rope partial-RoPE convention, and the cross-layer
+    Hyper-Connection signal propagation. Any regression in those would tip
+    generation back into a single-token collapse or pure ```` output (the
+    failure modes we hit while landing the architecture).
+
+    Marked ``@slow`` because the checkpoint is ~700 GB on disk and only loadable
+    on a multi-GPU host (``device_map="auto"`` plus FP8 dequant materializes the
+    weights in bf16). Run manually with::
+
+        RUN_SLOW=1 pytest tests/models/deepseek_v4/test_modeling_deepseek_v4.py::DeepseekV4IntegrationTest -k generation -s
+    """
+
+    model_id = "deepseek-ai/DeepSeek-V4-Flash"
+    prompt = "Pipeline parallelism in ai is "
+
+    def test_v4_flash_fp8_generation(self):
+        # ``dequantize=True`` so we can run on bf16-only kernels (needed for the
+        # ``grouped_mm`` path the routed experts hit). Eager attention so we
+        # exercise the same forward we tune the rest of the V4 modeling around.
+        quantization_config = FineGrainedFP8Config(dequantize=True)
+        config = AutoConfig.from_pretrained(self.model_id)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            config=config,
+            dtype="auto",
+            device_map="auto",
+            attn_implementation="eager",
+            quantization_config=quantization_config,
+        )
+
+        inputs = tokenizer(self.prompt, return_tensors="pt").to(model.device)
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+
+        # Snapshot of greedy-decoded text. The exact continuation is deterministic
+        # under ``do_sample=False`` for a fixed prompt — if this snapshot drifts,
+        # something in the V4 forward / RoPE / Q-rescale / HC stack changed.
+ expected = ( + "Pipeline parallelism in ai is driven by three key factors: the exponential increase in data " + "size, the development of increasingly powerful computational techniques (especially deep " + "learning), to handle this data, and the availability of massive computational resources on " + "which to run these methods, all of which are are well aligned with trends in industry, " + " academia and research" + ) + decoded = tokenizer.decode(output_ids[0], skip_special_tokens=False) + self.assertEqual(decoded, expected) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 873c46791b1b..7a0e71e77d33 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -106,6 +106,32 @@ "HiggsAudioV2TokenizerConfig": ["downsample_factor"], "CsmConfig": ["tie_codebooks_embeddings"], "DeepseekV2Config": ["norm_topk_prob"], + "DeepseekV4Config": [ + "attention_bias", + "compress_rates", + "compress_rope_theta", + "hc_mult", + "hc_sinkhorn_iters", + "hc_eps", + "index_n_heads", + "index_head_dim", + "index_topk", + "mlp_layer_types", + "num_key_value_heads", + "num_nextn_predict_layers", + "norm_topk_prob", + "o_groups", + "o_lora_rank", + "q_lora_rank", + "rope_parameters", + "rope_theta", + "routed_scaling_factor", + "router_jitter_noise", + "scoring_func", + "n_routed_experts", + "n_shared_experts", + "swiglu_limit", + ], "EsmFoldConfig": ["esm_ablate_pairwise", "esm_ablate_sequence", "esm_input_dropout", "esm_type"], "TrunkConfig": ["cpu_grad_checkpoint", "layer_drop"], "SeamlessM4TConfig": True,