
Add DeepSeek V4 #45643

Draft
ArthurZucker wants to merge 3 commits into main from add-deepseek-v4

Conversation

@ArthurZucker
Collaborator

Draft. Supersedes #45616.

@ArthurZucker mentioned this pull request Apr 25, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker
Collaborator Author

Outputs are valid now

penguinwu added a commit to penguinwu/oss-model-graph-break-corpus that referenced this pull request Apr 25, 2026
… Phase 1 config + runner

Three coupled changes:

1) discovery/perf.py — harden per Rocky's notes (2026-04-25) on
   pytorch/benchmarks/dynamo/common.py:
   - patch_torch_manual_seed(seed=1337) — call once at process start;
     monkey-patches torch.manual_seed so HF models' internal RNG calls
     don't drift between runs (per Animesh on HF model non-determinism).
   - eager_self_check — runs forward twice with cloned identical inputs;
     reports max_abs_diff + deterministic bool. Detects models still
     non-deterministic even with the seed patch.
   - warm_peak_mem flag — captures both cold (default) and post-warmup
     peak memory. Don't conflate the two.
   - compile_times — captures torch._dynamo.utils.compile_times() dict
     (22+ metrics: _compile.compile_inner, GraphLowering.run, etc.) for
     cross-comparable compile-time analysis vs upstream HF dashboard.
   - methodology comments updated to reference common.py line numbers.

2) experiments/configs/deepseek-v4-pro-phase1.json — config for the
   Phase 1 eval. Scaled-but-architecturally-complete: ALL V4 features
   active at production dims (head_dim=512, q_lora_rank=1536, num_hash_
   layers=3, index_n_heads=64, hc_mult=4, hybrid attention, MLA, etc.);
   only num_hidden_layers (61->4), n_routed_experts (384->16), and
   vocab_size (129280->4096) scaled to fit 1x H100 in bf16. Pins the
   transformers PR branch sha (huggingface/transformers#45643 @ a0a8482).

3) experiments/scripts/run_deepseek_v4_pro_phase1.py — self-contained
   runner. Reads the config, applies seed patch + TF32 high precision,
   instantiates the model, and runs 4 dimensions in sequence:
     Step 1: instantiate + eager forward (param count, peak mem)
     Step 2: torch._dynamo.explain (graph break analysis)
     Step 3: correctness vs eager (max_abs_diff + bitwise_equal)
     Step 4: tier-1 perf via measure_perf (eager_ms / compiled_ms /
             speedup / compile_s + compile_times breakdown)
   Writes per-row results to experiments/results/deepseek_v4_pro/
   phase1-tiny-<datestamp>/results.json. Top-level torch.compile.

Phase 1 eval not yet executed — runner is ready; smoke-tested perf.py
upgrade. See experiments/deepseek_v4_pro_eval_plan.md.
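
For illustration, a minimal sketch of the eager_self_check idea from item 1 (helper name, inputs handling, and return format are assumptions, not the corpus's actual code):

import torch

def eager_self_check(model, inputs):
    # Run the eager forward twice on cloned, identical inputs and compare.
    # A non-zero max_abs_diff means the model is still non-deterministic even
    # after the manual-seed patch (e.g. non-deterministic kernels).
    model.eval()
    with torch.no_grad():
        out1 = model(**{k: v.clone() for k, v in inputs.items()}).logits
        out2 = model(**{k: v.clone() for k, v in inputs.items()}).logits
    max_abs_diff = (out1 - out2).abs().max().item()
    return {"max_abs_diff": max_abs_diff, "deterministic": max_abs_diff == 0.0}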
Adds DeepSeek V4 with hybrid CSA/HCA attention, lightning indexer,
manifold-constrained hyper-connections, shared K=V MQA with grouped
low-rank output, and per-head attention sink. Includes tokenizer/auto
mappings, finegrained FP8 quantization support, and unit tests.
ArthurZucker and others added 2 commits April 28, 2026 19:27
No inheritance between HCA and CSA: each has its own cache (DynamicSlidingWindowLayer
subclass) and compressor (nn.Module subclass). HCA stays minimal (non-overlapping
windows, no indexer); CSA explicitly carries the overlap state + indexer. Shared
math factored into module-level helpers — no coff/overlap branching, no
_compress_rate_attr indirection. Also adds 'sliding_attention' to COMPRESSOR_CLASSES
with None so the three attention types are dispatched explicitly in one place.
Generation tests were assuming V4 supports advanced decoding modes (assisted
generation, prompt lookup, contrastive search, static-cache compile) that the
compressor's running-window cache state can't service — its buffer / pool /
overlap fields aren't rewindable across drafts and aren't compatible with
:class:`StaticCache`. Set the right opt-out flags so generate raises a clear
error early and the corresponding tests skip cleanly:

* ``_is_stateful = True``      — gates assisted / prompt-lookup paths.
* ``_can_compile_fullgraph = False`` — gates the static-cache test (would
  otherwise hand the compressor a :class:`StaticSlidingWindowLayer` with no
  ``update_compressor`` method).
* ``_supports_flex_attn = False`` — V4 only validates eager attention; the
  compressor / indexer paths weren't checked under flex / SDPA / flash kernels.
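
A minimal sketch of how these opt-out flags sit on the pretrained-model base class (illustrative placement only; see the actual modeling file for the real definitions):

from transformers import PreTrainedModel

class DeepseekV4PreTrainedModel(PreTrainedModel):
    # The compressor's running-window cache state is not rewindable across
    # drafts and is not StaticCache-compatible, hence the opt-outs below.
    _is_stateful = True               # gates assisted generation / prompt lookup
    _can_compile_fullgraph = False    # gates the static-cache compile path
    _supports_flex_attn = False       # only eager attention is validated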

Conversion mapping cleanup so save / load round-trips survive:

* Standardize on V3's ``apply_rotary_pos_emb_interleave`` for the partial-RoPE
  rotation, with a thin V4-side wrapper that permutes the rope channels back
  from the halves layout V3 leaves them in to the interleaved layout V4 was
  trained with — required because V4 is shared-KV (V == K rotated), so V's
  channel layout flows through ``wo_a`` / ``wo_b``.
* Restructure ``conversion_mapping.deepseek_v4`` into two passes: structural
  prefix renames first (``layers.X.attn.`` → ``model.layers.X.self_attn.``),
  then specific in-prefix renames on the already-prefixed HF-form keys
  (``...self_attn.compressor.norm.`` → ``...self_attn.compressor.kv_norm.``).
  A single-pass ordering loses information in either the forward or reverse
  direction (overlapping general / specific patterns conflict).
* Move the FP8 ``.scale`` → ``.weight_scale_inv`` rename out of the V4 static
  conversion list and into ``FineGrainedFP8HfQuantizer.update_weight_conversions``
  so the rule is only registered when FP8 dequant is actually active. Lets
  ``test_reverse_loading_mapping`` skip an unrelated FP8 rule on plain saves.
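
For the first bullet, a rough sketch of the halves-to-interleaved channel permutation the V4-side wrapper performs (hypothetical helper, not the PR's actual code):

import torch

def halves_to_interleaved(x: torch.Tensor) -> torch.Tensor:
    # [r0 .. r_{d/2-1} | i0 .. i_{d/2-1}]  ->  [r0, i0, r1, i1, ...]
    half = x.shape[-1] // 2
    return torch.stack((x[..., :half], x[..., half:]), dim=-1).flatten(-2)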

Test fixes:

* Skip ``test_reverse_loading_mapping`` with a docstring spelling out why the
  two-pass mapping can't satisfy that test's invariant (its Pass 2 source
  patterns are HF-form by design; ``test_save_load`` exercises the actual
  round-trip).
* Skip ``test_left_padding_compatibility`` — V4's compressor pre-pools
  ``compress_rate``-token windows before the attention mask is applied, so
  left padding shifts window boundaries and folds pad tokens into pooled
  KV entries (same fundamental limit as RecurrentGemma).
* Add ``model.to(torch_device)`` in the ``test_hidden_states_output`` override
  so cuda inputs don't hit a cpu model.
* ``test_tiny_generate_runs`` now passes ``eos_token_id=-1`` so a freshly
  initialised random model doesn't EOS-stop before max_new_tokens, making the
  shape assertion deterministic.
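
Roughly what that last tweak looks like in the test (illustrative values; model and input_ids come from the tiny-model test setup):

# eos_token_id=-1 means a freshly initialised random model can't hit EOS early,
# so the output length is always prompt length + max_new_tokens.
out = model.generate(input_ids, max_new_tokens=8, eos_token_id=-1, do_sample=False)
assert out.shape[-1] == input_ids.shape[-1] + 8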

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, deepseek_v4, finegrained_fp8

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45643&sha=b4b3a2


@vasqu (Contributor) left a comment


Ok, I went into details this time. Imo the RoPE is messy atm; I'm pretty sure it can be refactored into a more normal style.

# E2M1 (FP4) value table — checkpoints sometimes ship MoE experts as packed FP4
# (two e2m1 nibbles per int8 byte), so the "weight" dtype lands as ``int8`` /
# ``float4_e2m1fn_x2`` and we have to unpack before applying the scale grid.
_FP4_E2M1_LUT = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0)

Oh that's a bit awkward ngl - guess they did have to make a workaround for that. Only Blackwell has native FP4 support iirc
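
As an aside, roughly how a LUT like this gets used to unpack packed FP4 weights (sketch with an illustrative helper name and assumed nibble order; the per-block scale grid still has to be applied afterwards):

import torch

_FP4_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def unpack_fp4_e2m1(packed: torch.Tensor) -> torch.Tensor:
    # packed: int8 tensor where each byte carries two e2m1 nibbles
    codes = packed.view(torch.uint8)
    lo = (codes & 0x0F).long()          # low nibble (assumed to come first)
    hi = ((codes >> 4) & 0x0F).long()   # high nibble
    return _FP4_E2M1_LUT[torch.stack((lo, hi), dim=-1).flatten(-2)]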

from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP


def apply_rotary_pos_emb(

Imo we should refactor this to a more Gemma-like RoPE where we only apply it on one tensor - most of the time we only need to apply it to q, so also rotating k is a lot of unnecessary computation
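
Something like this single-tensor variant, presumably (sketch only; the exact rotate_half flavor would follow whatever layout the model actually uses):

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_single(x, cos, sin, unsqueeze_dim=1):
    # Apply RoPE to one tensor (q or k) so callers only rotate what they need.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)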

half = cos.shape[-1] // 2
cos = cos[..., :half].unsqueeze(unsqueeze_dim)
sin = sin[..., :half].unsqueeze(unsqueeze_dim)


Can just overwrite the RoPE forward to not duplicate instead, i.e.

emb = torch.cat((freqs, freqs), dim=-1)
is removed

Comment on lines +70 to +77
def _rotate(x: torch.Tensor) -> torch.Tensor:
    # ``unflatten`` gives `[..., rope_dim/2, 2]` so axis -2 indexes pairs and -1
    # indexes (real, imag). Promoting to fp32 matches the reference's precision.
    pairs = x.float().unflatten(-1, (-1, 2))
    x_re, x_im = pairs[..., 0], pairs[..., 1]
    rot_re = x_re * cos - x_im * sin
    rot_im = x_re * sin + x_im * cos
    return torch.stack([rot_re, rot_im], dim=-1).flatten(-2).to(x.dtype)

Is this not exactly

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

Imo we should refactor to use our normal patterns here then


The float cast can be incorporated like this then

q_embed = (q.float() * cos) + (rotate_half(q).float() * sin)
k_embed = (k.float() * cos) + (rotate_half(k).float() * sin)

Comment on lines +117 to +121
rope_theta (`float`): RoPE base for the main self-attention rotary.
compress_rope_theta (`float`): RoPE base for the compressed branches (paired with
``rope_scaling`` for YaRN).
partial_rotary_factor (`float`, *optional*): Fraction of head_dim that gets RoPE.
Defaults to ``qk_rope_head_dim / head_dim`` so cos/sin sizes to ``qk_rope_head_dim``.

Should be a layer types dict like in gemma3 / gemma4 so we have proper defaults etc

We can still use BC behavior in the post init to check for kwargs to construct defaults
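
Hypothetical shape of such a config (field names are assumptions loosely following the gemma3-style pattern, not the PR's actual API):

# Instead of flat rope_theta / compress_rope_theta kwargs, key the RoPE settings
# by layer type and let post_init() fill defaults; a BC path could still translate
# the old kwargs into this structure.
rope_parameters = {
    "main": {"rope_theta": 10000.0, "partial_rotary_factor": 0.25},
    "compressed": {"rope_theta": 10000.0, "rope_scaling": {"rope_type": "yarn", "factor": 4.0}},
}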

# SDPA / FlashAttention / FlexAttention kernels — leaving these ``False`` makes
# ``set_attn_implementation`` reject those backends instead of silently routing
# through them.
_supports_flash_attn = False

Hmm, ok that explains my earlier comment about the padded mask

Comment on lines +1462 to +1467
def get_input_embeddings(self):
    return self.embed_tokens

def set_input_embeddings(self, value):
    self.embed_tokens = value


Suggested change
def get_input_embeddings(self):
    return self.embed_tokens

def set_input_embeddings(self, value):
    self.embed_tokens = value

shouldn't be needed



@auto_docstring
class DeepseekV4Model(DeepseekV4PreTrainedModel):

We could still inherit from e.g. llama tbh, and add the extra stuff onto the init

position_ids=position_ids,
)
hidden_states = inputs_embeds.unsqueeze(2).expand(-1, -1, self.config.hc_mult, -1).contiguous()
cos_sin = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main")

Suggested change
cos_sin = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main")
position_embeddings = self.rotary_emb(inputs_embeds, position_ids=position_ids, layer_type="main")

Comment on lines +1530 to +1535
class DeepseekV4ForCausalLM(MixtralForCausalLM):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        self.model = DeepseekV4Model(config)

Suggested change
class DeepseekV4ForCausalLM(MixtralForCausalLM):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: DeepseekV4Config):
        super().__init__(config)
        self.model = DeepseekV4Model(config)
class DeepseekV4ForCausalLM(MixtralForCausalLM):
    pass

no?
