Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
777c272
honor the quant config's scale format and refuse
IlyasMoutawwakil Jun 22, 2026
122adb5
fix fp4 specific
IlyasMoutawwakil Jun 22, 2026
98034c9
strict
IlyasMoutawwakil Jun 22, 2026
f1e2235
deeper ep fix
IlyasMoutawwakil Jun 22, 2026
83625bc
test
IlyasMoutawwakil Jun 22, 2026
b34872e
style
IlyasMoutawwakil Jun 22, 2026
0e0550a
add assertion
IlyasMoutawwakil Jun 23, 2026
4a585fe
more ep plans
IlyasMoutawwakil Jun 23, 2026
0288a11
fold tp+ep checks and ep assert into one helper
IlyasMoutawwakil Jun 23, 2026
7d5976d
style
IlyasMoutawwakil Jun 23, 2026
8ea20bf
raise proper error upon ep request with no ep plan
IlyasMoutawwakil Jun 23, 2026
d288a9e
Merge branch 'main' into fix-glm-dsa
IlyasMoutawwakil Jun 23, 2026
d7a2fea
address anton's comments and make more modular
IlyasMoutawwakil Jun 23, 2026
99f28f1
fix repo
IlyasMoutawwakil Jun 23, 2026
b8f8eed
more modular
IlyasMoutawwakil Jun 23, 2026
4b04deb
more modular dsv2 topK router
IlyasMoutawwakil Jun 24, 2026
c015bd2
modular phimoe router
IlyasMoutawwakil Jun 24, 2026
6cf5856
fix
IlyasMoutawwakil Jun 24, 2026
c837936
reverting phimoe changes
IlyasMoutawwakil Jun 24, 2026
9cd6feb
last modular attempt
IlyasMoutawwakil Jun 24, 2026
31f2139
correct fix ?
IlyasMoutawwakil Jun 24, 2026
5c2b2af
add BC variation just in case
vasqu Jun 24, 2026
918d6c4
clearer message
vasqu Jun 24, 2026
c54aee9
post init workaround?
vasqu Jun 24, 2026
6de8bff
remove the warning
vasqu Jun 24, 2026
d1dcde0
Merge branch 'main' into fix-glm-dsa
vasqu Jun 24, 2026
54ab3ee
fix CI
vasqu Jun 24, 2026
8d9846b
ci
vasqu Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/transformers/integrations/deepgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,28 @@ def _is_sm100(device: torch.device) -> bool:
return torch.cuda.get_device_capability(device)[0] >= 10


def _assert_sm100_scales_are_ue8m0(scale: torch.Tensor) -> None:
    """Refuse plain ``float32`` expert scales when the tensor lives on an SM100 device.

    DeepGEMM on B200 (SM100) only supports UE8M0 (power-of-two) scale factors; the float32
    scales accepted on H100 (SM90) have no SM100 path. UE8M0 scales are loaded as
    ``float8_e8m0fnu`` (the loader normalizes even float32-container checkpoints such as
    dsv4-flash-base), so seeing ``float32`` here indicates a genuinely non-UE8M0 checkpoint.
    Raising is preferred over letting ``_coerce_sf_for_kernel`` round the values silently
    and corrupt the output.

    Args:
        scale: Scale-factor tensor to inspect; only its ``device`` and ``dtype`` are read.

    Raises:
        ValueError: If ``scale`` is on an SM100 device and its dtype is ``torch.float32``.
    """
    # The device is probed first and the dtype is only inspected on SM100. Everything else
    # (SM90 with float32 SFs, or scales that are already UE8M0) passes through untouched.
    on_sm100 = _is_sm100(scale.device)
    if on_sm100 and scale.dtype == torch.float32:
        message = (
            "DeepGEMM's Blackwell (SM100) experts kernel requires power-of-two (UE8M0) scale "
            "factors, but this checkpoint's expert scales are plain float32 "
            "(quantization_config.scale_fmt='float'). Rounding them to UE8M0 would scale the "
            "dequantized expert weights incorrectly and silently corrupt the output. Use a "
            "checkpoint quantized with scale_fmt='ue8m0', or an experts implementation that "
            "consumes float32 block scales directly, e.g. "
            "`model.set_experts_implementation('grouped_mm')`."
        )
        raise ValueError(message)


_DEEPGEMM_VISITED_DEVICES: set[int] = set()


Expand Down Expand Up @@ -556,6 +578,7 @@ def deepgemm_fp8_fp4_experts_forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
_assert_single_device(hidden_states.device, context="experts")
_assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv)

if self.activation_scheme == "static":
raise NotImplementedError("DeepGEMM experts dispatch does not support static activation quantization.")
Expand Down Expand Up @@ -715,6 +738,8 @@ def deepgemm_fp8_fp4_megamoe_experts_forward(
`transform_weights_for_mega_moe((gate_up, gate_up_sf), (down, down_sf))`.
- `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped.
"""
_assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv)

if self.gate_up_proj.dtype != torch.int8:
raise RuntimeError(
f"DeepGEMM Mega MoE requires FP4-packed expert weights (dtype=`int8`), got "
Expand Down
41 changes: 11 additions & 30 deletions src/transformers/integrations/finegrained_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@

@functools.cache
def _get_ue8m0_dtype() -> torch.dtype:
"""Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support."""
"""Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support.

UE8M0 scales are always stored/consumed as this single dtype — the kernels (Triton
finegrained + DeepGEMM) read it natively, and supporting the same scales in mixed
container dtypes would be a mess — so fail loudly rather than fall back."""
if not hasattr(torch, "float8_e8m0fnu"):
raise RuntimeError(
"scale_fmt='ue8m0' requires torch.float8_e8m0fnu, which is only available in "
Expand Down Expand Up @@ -294,7 +298,10 @@ def __init__(
sf_dtype = _get_ue8m0_dtype() if scale_fmt == "ue8m0" else torch.float32
scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype))
self.weight_scale_inv = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype),
requires_grad=sf_dtype.is_floating_point,
)

if self.activation_scheme == "static":
self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
Expand Down Expand Up @@ -871,8 +878,8 @@ def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32)
# DeepSeek V4-style storage (`scale_fmt="ue8m0"`): round inv_scales to UE8M0-representable
# values (powers of 2) and cast to `float8_e8m0fnu` byte storage so the on-disk dtype
# matches the parameter allocation in `FP8Linear`/`FP8Experts`.
# values (powers of 2) and cast to the UE8M0 byte storage so the on-disk dtype matches the
# parameter allocation in `FP8Linear`/`FP8Experts`.
if self.hf_quantizer.quantization_config.scale_fmt == "ue8m0":
inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny))))
inv_scales = inv_scales.to(_get_ue8m0_dtype())
Expand Down Expand Up @@ -1045,29 +1052,3 @@ def reverse_op(self) -> ConversionOps:
# checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``)
# whether the in-memory state stayed quantized or was dequantized for compute.
return Fp8Quantize(self.hf_quantizer)


class Fp8DecodeScale(ConversionOps):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any ideas as to why this part was dropped ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because i added support for ue8m0 scales in finegrained-fp8 v3, this was needed for minimax m3 with the v2, but not anymore, it also wastes memory

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ue8m0 scales are a bit messy: some store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no special reason 😭 I'm trying to tighten the contract and honor the config at all times, because supporting all the on-disk variations would be more complicated

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! Just to be sure: if we remove it now, it would not break existing checkpoints that are in mxfp8 format, right?

@IlyasMoutawwakil IlyasMoutawwakil Jun 23, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they will work fine, even better, because I just noticed that the fp32 scales even avoid the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64 . I also ran the minimax m3 integration tests on the b200.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep sounds good, just require the version of the kernel for that path to error out properly if kernel version not installed

@IlyasMoutawwakil IlyasMoutawwakil Jun 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do pin the v3 in our lazy loading

"""Decode MXFP8 ``ue8m0`` per-block scales (stored as ``uint8`` exponents) into the
float32 multiplicative scales the FP8 compute path expects.

Native MXFP8 loading (``dequantize=False``) keeps weights in ``float8_e4m3fn`` and only
needs the sibling ``*.weight_scale_inv`` tensors turned from raw E8M0 bytes into real
scales (``2 ** (byte - 127)``). Prepended to each weight converter, this op runs before
any merge/concat collapses the per-expert structure: it rewrites only the ``uint8`` scale
entries and passes weights (and already-float scales) through untouched.
"""

# Constructor: keeps a reference to the quantizer on the instance as `self.hf_quantizer`.
# Neither `_decode` nor `convert` in this class reads that attribute.
# NOTE(review): the expected type of `hf_quantizer` is not visible here — confirm against callers.
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

@staticmethod
def _decode(tensor: torch.Tensor) -> torch.Tensor:
    """Turn raw E8M0 exponent bytes into float32 scales; leave other tensors alone.

    Args:
        tensor: Either a ``uint8`` tensor of E8M0 exponent bytes or any other tensor.

    Returns:
        For ``uint8`` input, a new ``float32`` tensor holding ``2 ** (byte - 127)``.
        For any other dtype, the very same tensor object, unchanged.
    """
    if tensor.dtype != torch.uint8:
        # Weights and already-decoded scales pass straight through.
        return tensor
    # E8M0 keeps one biased exponent byte per block; subtracting the bias of 127 gives the exponent.
    exponents = tensor.to(torch.float32) - 127.0
    return exponents.exp2()

def convert(self, input_dict: dict[str, list[torch.Tensor] | torch.Tensor], **kwargs):
    """Apply ``_decode`` to every tensor in ``input_dict`` and return a new dict.

    Args:
        input_dict: Mapping from key to either a single tensor or a list of tensors.
        **kwargs: Accepted but not used by this method.

    Returns:
        A new dict with the same keys. List values become new lists of decoded tensors;
        any other value is passed through ``_decode`` directly.
    """
    decoded = {}
    for name, entry in input_dict.items():
        if isinstance(entry, list):
            decoded[name] = [self._decode(item) for item in entry]
        else:
            decoded[name] = self._decode(entry)
    return decoded
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,12 @@ def tp_plan(self) -> dict[str, str]:
The full tp plan for the model's modules
"""
if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
if not self._ep_plan:
raise ValueError(
f"Expert parallelism was requested (`enable_expert_parallel=True`), but "
f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a "
f"`base_model_ep_plan` to its config, or disable expert parallelism."
)
Comment on lines +1470 to +1475

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loud failure on missing ep plan

return self._ep_plan
return self._tp_plan

Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/afmoe/configuration_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class AfmoeConfig(PreTrainedConfig):
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_model_ep_plan = {
"layers.*.mlp.router": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}

vocab_size: int = 200192
hidden_size: int = 2048
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def forward(self, hidden_states):
hidden_states_flat = hidden_states.view(-1, hidden_dim)

# Get routing decisions (returns flattened top-k)
router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
_, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/afmoe/modular_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def forward(self, hidden_states):
hidden_states_flat = hidden_states.view(-1, hidden_dim)

# Get routing decisions (returns flattened top-k)
router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
_, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class Cohere2MoeConfig(PreTrainedConfig):
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}

vocab_size: int = 256000
hidden_size: int = 8192
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Cohere2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.expert_selection_fn = config.expert_selection_fn
self.norm_topk_prob = config.norm_topk_prob
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size))
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/cohere2_moe/modular_cohere2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Cohere2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.expert_selection_fn = config.expert_selection_fn
self.norm_topk_prob = config.norm_topk_prob
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,23 @@ class DeepseekOcr2TextConfig(PreTrainedConfig):
attention_dropout: float | None = 0.0
mlp_bias: bool = False
head_dim: int | None = None
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}
attribute_map = {
"num_local_experts": "n_routed_experts",
"num_experts": "n_routed_experts",
Comment thread
IlyasMoutawwakil marked this conversation as resolved.
Outdated
}
n_group: int | None = None
n_routed_experts: int = 64
n_shared_experts: int = 2
routed_scaling_factor: float = 1.0
topk_group: int | None = None
topk_method: str | None = "greedy"
norm_topk_prob: bool | None = False
num_experts_per_tok: int | None = None
moe_intermediate_size: int = 1407

Expand Down
57 changes: 33 additions & 24 deletions src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,7 @@ class DeepseekOcr2TextExperts(nn.Module):

def __init__(self, config):
super().__init__()
self.num_experts = config.n_routed_experts
self.num_experts = config.num_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
Expand Down Expand Up @@ -1194,48 +1194,55 @@ def forward(
return final_hidden_states


class DeepseekOcr2TextMoe(nn.Module):
class DeepseekOcr2TextTopkRouter(nn.Module):
def __init__(self, config: DeepseekOcr2TextConfig):
super().__init__()
self.config = config
self.experts = DeepseekOcr2TextExperts(config)
self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size)
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.norm_topk_prob = config.norm_topk_prob
self.hidden_dim = config.hidden_size
self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.num_group = config.n_group
self.top_k = config.num_experts_per_tok
self.topk_group = config.topk_group

def route_tokens_to_experts(self, router_logits):
batch_size, seq_len, hidden_dim = router_logits.shape
router_logits = router_logits.view(-1, hidden_dim)
router_logits = router_logits.softmax(dim=-1, dtype=torch.float32)
def forward(self, hidden_states):
hidden_states = hidden_states.view(-1, self.hidden_dim)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
scores = router_logits.softmax(dim=-1, dtype=torch.float32)
if self.topk_method == "greedy":
topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False)
topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
elif self.topk_method == "group_limited_greedy":
group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values
group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group)
.reshape(batch_size * seq_len, -1)
.expand(-1, self.num_group, self.num_experts // self.num_group)
.reshape(-1, self.num_experts)
)
tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0)
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
topk_weights = topk_weights * self.routed_scaling_factor
return router_logits, topk_weights, topk_indices

topk_weight = topk_weight * self.routed_scaling_factor
return topk_idx, topk_weight

class DeepseekOcr2TextMoe(nn.Module):
def __init__(self, config: DeepseekOcr2TextConfig):
super().__init__()
self.config = config
self.experts = DeepseekOcr2TextExperts(config)
self.gate = DeepseekOcr2TextTopkRouter(config)
self.shared_experts = DeepseekOcr2TextMLP(
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
_, topk_weights, topk_indices = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
Expand Down Expand Up @@ -1330,7 +1337,9 @@ class DeepseekOcr2TextPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, DeepseekOcr2TextExperts):
if isinstance(module, DeepseekOcr2TextTopkRouter):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DeepseekOcr2TextExperts):
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,6 @@ class DeepseekOcr2TextConfig(DeepseekV2Config):
# Remove unused attributes inherited from DeepseekV2Config
first_k_dense_replace = AttributeError()
kv_lora_rank = AttributeError()
norm_topk_prob = AttributeError()
Comment thread
vasqu marked this conversation as resolved.
q_lora_rank = AttributeError()
qk_nope_head_dim = AttributeError()
qk_rope_head_dim = AttributeError()
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/deepseek_v2/configuration_deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ class DeepseekV2Config(PreTrainedConfig):
attention_dropout: float | None = 0.0
mlp_bias: bool = False
head_dim: int | None = None
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}
attribute_map = {
"num_local_experts": "n_routed_experts",
"num_experts": "n_routed_experts",
}
first_k_dense_replace: int = 0
kv_lora_rank: int = 512
q_lora_rank: int | None = 1536
Expand Down
Loading
Loading