2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -557,6 +557,8 @@
title: DeepSeek-V2
- local: model_doc/deepseek_v3
title: DeepSeek-V3
- local: model_doc/deepseek_v4
title: DeepSeek-V4
- local: model_doc/dialogpt
title: DialoGPT
- local: model_doc/diffllama
39 changes: 39 additions & 0 deletions docs/source/en/model_doc/deepseek_v4.md
@@ -0,0 +1,39 @@
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on {release_date} and added to Hugging Face Transformers on 2026-04-28.*

# DeepSeek-V4

[DeepSeek-V4](https://huggingface.co/deepseek-ai) is a family of MoE language models released by DeepSeek. Relative
to DeepSeek-V3, V4 replaces MLA with sliding-window attention plus a per-layer KV Compressor, swaps residual
connections for Hyper-Connections, routes the first few layers via a static token-id hash, and drops expert groups.

This implementation covers `DeepSeek-V4-Flash`, `DeepSeek-V4-Pro`, and their `-Base` pretrained siblings. All
four checkpoints share the same architecture; they differ only in width, depth, expert count, and weights.
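
A minimal generation sketch (the checkpoint id below is a placeholder, not a confirmed repository name; substitute whatever DeepSeek publishes on the Hub):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "deepseek-ai/DeepSeek-V4-Flash"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map="auto")

inputs = tokenizer("Sliding-window attention keeps the KV cache small by", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```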

## DeepseekV4Config

[[autodoc]] DeepseekV4Config

## DeepseekV4Model

[[autodoc]] DeepseekV4Model
- forward

## DeepseekV4ForCausalLM

[[autodoc]] DeepseekV4ForCausalLM
- forward
8 changes: 8 additions & 0 deletions src/transformers/activations.py
@@ -214,6 +214,13 @@ def forward(self, input):
return squared


class SqrtSoftplusActivation(nn.Module):
"""sqrt(softplus(x)) — the router scoring function used by DeepSeek V4."""

def forward(self, input):
return nn.functional.softplus(input).sqrt()


class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
@@ -334,6 +341,7 @@ def forward(self, input: Tensor) -> Tensor:
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"sqrtsoftplus": SqrtSoftplusActivation,
"swish": nn.SiLU,
"tanh": nn.Tanh,
"prelu": nn.PReLU,
83 changes: 66 additions & 17 deletions src/transformers/cache_utils.py
@@ -23,10 +23,31 @@
logger = logging.get_logger(__name__)


# Registry mapping ``config.layer_types[i]`` -> the dynamic cache layer class to build for
# that layer. ``DynamicCache.__init__`` consults this mapping when a ``config`` is provided
# so models with custom layer types (e.g. DeepSeek-V4's CSA / HCA) can register their own
# cache-layer subclass instead of shipping a model-specific ``Cache`` subclass.
#
# A cache layer subclass with a class attribute ``layer_type = "..."`` auto-registers via
# ``CacheLayerMixin.__init_subclass__``. Each registered class must accept a
# ``PreTrainedConfig`` (the decoder text config) as the only positional argument.
LAYER_TYPE_CACHE_MAPPING: dict[str, type] = {}


class CacheLayerMixin(ABC):
"""Base, abstract class for a single layer's cache."""

is_compileable = False
# Subclasses can set ``layer_type`` to auto-register themselves in
# ``LAYER_TYPE_CACHE_MAPPING`` at import time (used by ``DynamicCache`` to dispatch
# per-layer cache classes from ``config.layer_types``).
layer_type: str | None = None

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
layer_type = cls.__dict__.get("layer_type", None)
if layer_type is not None:
LAYER_TYPE_CACHE_MAPPING[layer_type] = cls

def __init__(self):
self.keys: torch.Tensor | None = None
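
A hedged sketch of what the registry enables for a model with a custom layer type; the class name and `compress_rate` attribute below are illustrative, not the actual DeepSeek-V4 implementation:

```python
from transformers.cache_utils import DynamicLayer, LAYER_TYPE_CACHE_MAPPING

class CompressedSparseAttentionCacheLayer(DynamicLayer):
    # Setting layer_type auto-registers this class via CacheLayerMixin.__init_subclass__,
    # so DynamicCache picks it whenever config.layer_types[i] == "compressed_sparse_attention".
    layer_type = "compressed_sparse_attention"

    def __init__(self, config):
        super().__init__(config)
        # Illustrative: read whatever model-specific attribute the cache layer needs.
        self.compress_rate = getattr(config, "compress_rate", 4)

assert LAYER_TYPE_CACHE_MAPPING["compressed_sparse_attention"] is CompressedSparseAttentionCacheLayer
```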
@@ -93,6 +114,9 @@ class DynamicLayer(CacheLayerMixin):

is_sliding = False

def __init__(self, config: PreTrainedConfig | None = None):
super().__init__()

def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None:
self.dtype, self.device = key_states.dtype, key_states.device
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
@@ -171,8 +195,14 @@ class DynamicSlidingWindowLayer(DynamicLayer):

is_sliding = True

def __init__(self, sliding_window: int):
def __init__(self, config: PreTrainedConfig | None = None, sliding_window: int | None = None):
super().__init__()
# Accept either a config (registry-style construction via LAYER_TYPE_CACHE_MAPPING)
# or a raw ``sliding_window`` int (legacy callers).
if sliding_window is None:
if config is None:
raise ValueError("Either `config` or `sliding_window` must be provided.")
sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
self.sliding_window = sliding_window
self.cumulative_length = 0
self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)
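
Both construction styles the widened signature accepts, as a quick sketch (window size hypothetical):

```python
from transformers import MistralConfig
from transformers.cache_utils import DynamicSlidingWindowLayer

# Legacy call sites: pass the window size directly.
layer = DynamicSlidingWindowLayer(sliding_window=4096)

# Registry-style construction: pass the decoder config and let the layer read
# config.sliding_window (falling back to config.attention_chunk_size).
config = MistralConfig(sliding_window=4096)
layer = DynamicSlidingWindowLayer(config)
```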
@@ -732,6 +762,9 @@ def crop(self, max_length: int):


class LinearAttentionLayer(LinearAttentionCacheLayerMixin):
def __init__(self, config: PreTrainedConfig | None = None):
super().__init__()

def lazy_initialization(
self, conv_states: torch.Tensor | None = None, recurrent_states: torch.Tensor | None = None
) -> None:
@@ -808,7 +841,7 @@ class LinearAttentionAndFullAttentionLayer(LinearAttentionLayer, DynamicLayer):
# The dynamic Attention part makes it non-compileable
is_compileable = False

def __init__(self):
def __init__(self, config: PreTrainedConfig | None = None):
DynamicLayer.__init__(self)
LinearAttentionLayer.__init__(self)

@@ -831,6 +864,29 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
DynamicLayer.reorder_cache(self, beam_idx)


# Pre-register the standard layer types (some classes are shared between multiple types,
# e.g. ``DynamicSlidingWindowLayer`` covers both ``"sliding_attention"`` and
# ``"chunked_attention"`` — those need an explicit map entry rather than the
# auto-registration via ``CacheLayerMixin.__init_subclass__``).
LAYER_TYPE_CACHE_MAPPING.update(
{
"full_attention": DynamicLayer,
# From a cache point of view, sliding and chunked are the same in how they should behave;
# only the mask differs.
"sliding_attention": DynamicSlidingWindowLayer,
"chunked_attention": DynamicSlidingWindowLayer,
# Linear-attention-shaped layers (mamba / conv / pure linear-attention / moe placeholders)
# don't grow per-token KV; they're tracked just so position bookkeeping stays consistent.
"mamba": LinearAttentionLayer,
"conv": LinearAttentionLayer,
"linear_attention": LinearAttentionLayer,
"moe": LinearAttentionLayer,
# Hybrid layers (e.g. zamba / zamba2) carry both a linear-attention state and a dynamic-attention state.
"hybrid": LinearAttentionAndFullAttentionLayer,
}
)


class Cache:
"""
A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
@@ -1240,20 +1296,13 @@ def __init__(
layer_types = layer_types[: -decoder_config.num_kv_shared_layers]

for layer_type in layer_types:
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
# states they should return - only the mask changes to make them different at the end!
if layer_type in ("sliding_attention", "chunked_attention"):
layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
# Note: we want moe layers to be LinearAttentionLayer, so that we can correctly grab sequence length etc from attention layers.
# Since moe layers will stay empty (they don't need any cache), we don't want them to collide for mask creation etc
# TODO: maybe use a dummy layer in those cases, or a dictionary {idx: Layer} for self.layers, so that we can skip
# the indices we don't need
elif layer_type in ("mamba", "conv", "linear_attention", "moe"):
layers.append(LinearAttentionLayer())
elif layer_type == "hybrid":
layers.append(LinearAttentionAndFullAttentionLayer())
else:
layers.append(DynamicLayer())
# Dispatch through the registry — ``LAYER_TYPE_CACHE_MAPPING`` ships with the
# standard layer types pre-registered, and models with custom layer types
# (e.g. DeepSeek-V4's CSA / HCA) register their own classes there. Each class
# is instantiated with the decoder config so it can read whatever attributes
# it needs (sliding_window, compress_rate, ...).
cache_cls = LAYER_TYPE_CACHE_MAPPING.get(layer_type, DynamicLayer)
layers.append(cache_cls(decoder_config))

# In this case, use the passed data to already fill in the Cache
if ddp_cache_data is not None:
@@ -1353,7 +1402,7 @@ def __init__(

layers = []
for layer_type in layer_types:
if layer_type == "sliding_attention":
if layer_type in ("sliding_attention", "compressed_sparse_attention", "heavily_compressed_attention"):
layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
elif layer_type == "chunked_attention":
# From a cache point of view, both sliding and chunked are the same in how they should behave and how many
2 changes: 2 additions & 0 deletions src/transformers/configuration_utils.py
@@ -63,6 +63,8 @@
"full_attention",
"sliding_attention",
"chunked_attention",
"compressed_sparse_attention", # CSA, used in deepseek_v4
"heavily_compressed_attention", # HCA, used in deepseek_v4
"linear_attention", # used in minimax
"conv", # used in LFMv2
"mamba",
129 changes: 127 additions & 2 deletions src/transformers/conversion_mapping.py
@@ -97,6 +97,128 @@ def _build_checkpoint_conversion_mapping():
"altclip": [
WeightRenaming(source_patterns=r"layer\.", target_patterns="layers."),
],
"deepseek_v4": [
# Upstream checkpoint uses a flatter, V3-style namespace: ``attn`` / ``ffn``
# instead of ``self_attn`` / ``mlp``, ``attn_norm`` / ``ffn_norm`` instead of
# ``input_layernorm`` / ``post_attention_layernorm``, ``hc_attn_*`` / ``hc_ffn_*``
# for the Hyper-Connection params (we wrap them in ``attn_hc`` / ``ffn_hc``
# submodules), ``embed`` / ``head`` / bare ``norm`` for the model head, and
# ``hc_head_*`` for the final HC collapse. The Indexer's compressor tree is
# nested under ``attn.indexer.compressor.*`` upstream but flattened onto the
# Indexer module here. FP8 scales arrive as ``.scale`` and need to become
# ``.weight_scale_inv`` to match :class:`FineGrainedFP8Linear`.
#
# Ordering matters for save round-tripping: :func:`revert_weight_conversion`
# reverses the order *and* each transform, so a structural prefix-only rule
# placed before a specific in-prefix rename would steal the reverse match
# and emit ``layers.X.attn.sinks`` instead of ``layers.X.attn.attn_sink``.
# We split into two passes: structural prefix renames first (so they apply
# last on save / first on load), then specific in-prefix renames that
# operate on the already-prefixed keys.
#
# FP8 ``.scale`` → ``.weight_scale_inv`` rename lives in the FP8 quantizer's
# ``update_weight_conversions`` (only kicks in when FP8 dequant is active),
# so the V4 static mapping below stays free of FP8-only rules.
# ---- Pass 1: top-level + structural prefix renames ----
WeightRenaming(source_patterns=r"^embed\.weight$", target_patterns="model.embed_tokens.weight"),
WeightRenaming(source_patterns=r"^head\.weight$", target_patterns="lm_head.weight"),
WeightRenaming(source_patterns=r"^norm\.weight$", target_patterns="model.norm.weight"),
WeightRenaming(source_patterns=r"^hc_head_fn$", target_patterns="model.hc_head.hc_fn"),
WeightRenaming(source_patterns=r"^hc_head_base$", target_patterns="model.hc_head.hc_base"),
WeightRenaming(source_patterns=r"^hc_head_scale$", target_patterns="model.hc_head.hc_scale"),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.attn_norm\.",
target_patterns=r"model.layers.\1.input_layernorm.",
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.ffn_norm\.",
target_patterns=r"model.layers.\1.post_attention_layernorm.",
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_attn_fn$", target_patterns=r"model.layers.\1.attn_hc.fn"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_attn_base$", target_patterns=r"model.layers.\1.attn_hc.base"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_attn_scale$", target_patterns=r"model.layers.\1.attn_hc.scale"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_ffn_fn$", target_patterns=r"model.layers.\1.ffn_hc.fn"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_ffn_base$", target_patterns=r"model.layers.\1.ffn_hc.base"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.hc_ffn_scale$", target_patterns=r"model.layers.\1.ffn_hc.scale"
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.attn\.",
target_patterns=r"model.layers.\1.self_attn.",
),
WeightRenaming(
source_patterns=r"^layers\.(\d+)\.ffn\.",
target_patterns=r"model.layers.\1.mlp.",
),
# ---- Pass 2: in-prefix specific renames (operate on already-prefixed keys) ----
# These can safely run after the structural prefix renames because their
# source patterns include the ``model.layers.X.self_attn.`` / ``model.layers.X.mlp.``
# prefix. On reverse the order flips so these undo first, restoring the
# specific upstream names *before* the structural rules strip the prefix.
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.attn_sink$",
target_patterns=r"model.layers.\1.self_attn.sinks",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.norm\.",
target_patterns=r"model.layers.\1.self_attn.compressor.indexer.kv_norm.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.ape$",
target_patterns=r"model.layers.\1.self_attn.compressor.indexer.position_bias",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.compressor\.",
target_patterns=r"model.layers.\1.self_attn.compressor.indexer.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.indexer\.",
target_patterns=r"model.layers.\1.self_attn.compressor.indexer.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.compressor\.norm\.",
target_patterns=r"model.layers.\1.self_attn.compressor.kv_norm.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.self_attn\.compressor\.ape$",
target_patterns=r"model.layers.\1.self_attn.compressor.position_bias",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w1\.",
target_patterns=r"model.layers.\1.mlp.shared_experts.gate_proj.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w2\.",
target_patterns=r"model.layers.\1.mlp.shared_experts.down_proj.",
),
WeightRenaming(
source_patterns=r"^model\.layers\.(\d+)\.mlp\.shared_experts\.w3\.",
target_patterns=r"model.layers.\1.mlp.shared_experts.up_proj.",
),
WeightConverter(
source_patterns=[
"experts.*.w1.weight",
"experts.*.w3.weight",
],
target_patterns="experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_patterns="experts.*.w2.weight",
target_patterns="experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"llava": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
@@ -687,8 +809,11 @@ def get_model_conversion_mapping(
if add_legacy:
weight_conversions.extend(get_checkpoint_conversion_mapping("legacy"))

# Add the ones from the quantizer as well if provided
# Let the quantizer rewrite / augment the conversion pipeline. This is where the
# FP8 dequantizer (when ``dequantize=True``) prepends a ``Fp8Dequantize`` op to
# every existing converter so that per-block scales are applied *before* any
# expert-merge / concat ops flatten the per-expert structure away.
if hf_quantizer is not None:
weight_conversions.extend(hf_quantizer.get_weight_conversions())
weight_conversions = hf_quantizer.update_weight_conversions(weight_conversions)

return weight_conversions
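
To make the ordering argument in the `deepseek_v4` comment concrete, here is a toy round-trip using plain `re.sub` instead of the actual `WeightRenaming` machinery (the reverse rules are inverted by hand):

```python
import re

# Load direction: structural prefix rename first, then the specific in-prefix rename.
forward = [
    (r"^layers\.(\d+)\.attn\.", r"model.layers.\1.self_attn."),
    (r"^model\.layers\.(\d+)\.self_attn\.attn_sink$", r"model.layers.\1.self_attn.sinks"),
]
# Save direction: order reversed and each rule inverted, so the specific rename undoes
# first and the structural rule then strips the prefix.
reverse = [
    (r"^model\.layers\.(\d+)\.self_attn\.sinks$", r"model.layers.\1.self_attn.attn_sink"),
    (r"^model\.layers\.(\d+)\.self_attn\.", r"layers.\1.attn."),
]

def apply(rules, key):
    for pattern, repl in rules:
        key = re.sub(pattern, repl, key)
    return key

upstream = "layers.3.attn.attn_sink"
hf_key = apply(forward, upstream)        # -> model.layers.3.self_attn.sinks
assert apply(reverse, hf_key) == upstream  # round-trips back to the upstream name
```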
5 changes: 4 additions & 1 deletion src/transformers/core_model_loading.py
@@ -156,8 +156,11 @@ def convert(
target_pattern = self.get_target_pattern(target_patterns)
all_tensors = []
# Very important to keep the relative order of the source patterns here, so we iterate over them not the
# input directly as it's unordered!
# input directly as it's unordered! Skip patterns that prior ops in the chain (e.g. ``Fp8Dequantize``)
# have already consumed and dropped from ``input_dict``.
for source_pattern in source_patterns:
if source_pattern not in input_dict:
continue
tensors = input_dict[source_pattern]
if isinstance(tensors, list):
all_tensors.extend(tensors)