Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/transformers/integrations/deepgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,28 @@ def _is_sm100(device: torch.device) -> bool:
return torch.cuda.get_device_capability(device)[0] >= 10


def _assert_sm100_scales_are_ue8m0(scale: torch.Tensor) -> None:
"""On B200 (SM100) DeepGEMM only supports UE8M0 (power-of-two) scales; the float32 scales
that work on H100 (SM90) have no SM100 path. UE8M0 scales load as ``float8_e8m0fnu`` (the
loader normalizes even float32-container checkpoints like dsv4-flash-base), so a plain
``float32`` scale here means a genuine non-UE8M0 checkpoint — fail loud rather than let
``_coerce_sf_for_kernel`` silently round it and corrupt the output.
"""
if not _is_sm100(scale.device):
return # SM90 consumes float32 SFs directly (no UE8M0 round).
if scale.dtype != torch.float32:
return # already UE8M0 (`float8_e8m0fnu`) — kernel-ready as-is.
raise ValueError(
"DeepGEMM's Blackwell (SM100) experts kernel requires power-of-two (UE8M0) scale "
"factors, but this checkpoint's expert scales are plain float32 "
"(quantization_config.scale_fmt='float'). Rounding them to UE8M0 would scale the "
"dequantized expert weights incorrectly and silently corrupt the output. Use a "
"checkpoint quantized with scale_fmt='ue8m0', or an experts implementation that "
"consumes float32 block scales directly, e.g. "
"`model.set_experts_implementation('grouped_mm')`."
)


_DEEPGEMM_VISITED_DEVICES: set[int] = set()


Expand Down Expand Up @@ -556,6 +578,7 @@ def deepgemm_fp8_fp4_experts_forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
_assert_single_device(hidden_states.device, context="experts")
_assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv)

if self.activation_scheme == "static":
raise NotImplementedError("DeepGEMM experts dispatch does not support static activation quantization.")
Expand Down Expand Up @@ -715,6 +738,8 @@ def deepgemm_fp8_fp4_megamoe_experts_forward(
`transform_weights_for_mega_moe((gate_up, gate_up_sf), (down, down_sf))`.
- `config.swiglu_limit` (optional): SwiGLU clamp; absent → unclamped.
"""
_assert_sm100_scales_are_ue8m0(self.down_proj_scale_inv)

if self.gate_up_proj.dtype != torch.int8:
raise RuntimeError(
f"DeepGEMM Mega MoE requires FP4-packed expert weights (dtype=`int8`), got "
Expand Down
41 changes: 11 additions & 30 deletions src/transformers/integrations/finegrained_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@

@functools.cache
def _get_ue8m0_dtype() -> torch.dtype:
"""Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support."""
"""Return ``torch.float8_e8m0fnu`` or raise a clear error on torch without FP8 support.

UE8M0 scales are always stored/consumed as this single dtype — the kernels (Triton
finegrained + DeepGEMM) read it natively, and supporting the same scales in mixed
container dtypes would be a mess — so fail loudly rather than fall back."""
if not hasattr(torch, "float8_e8m0fnu"):
raise RuntimeError(
"scale_fmt='ue8m0' requires torch.float8_e8m0fnu, which is only available in "
Expand Down Expand Up @@ -294,7 +298,10 @@ def __init__(
sf_dtype = _get_ue8m0_dtype() if scale_fmt == "ue8m0" else torch.float32
scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype))
self.weight_scale_inv = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=sf_dtype),
requires_grad=sf_dtype.is_floating_point,
)

if self.activation_scheme == "static":
self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
Expand Down Expand Up @@ -871,8 +878,8 @@ def _quantize_one(self, key: str, value: torch.Tensor) -> dict[str, torch.Tensor
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32)
# DeepSeek V4-style storage (`scale_fmt="ue8m0"`): round inv_scales to UE8M0-representable
# values (powers of 2) and cast to `float8_e8m0fnu` byte storage so the on-disk dtype
# matches the parameter allocation in `FP8Linear`/`FP8Experts`.
# values (powers of 2) and cast to the UE8M0 byte storage so the on-disk dtype matches the
# parameter allocation in `FP8Linear`/`FP8Experts`.
if self.hf_quantizer.quantization_config.scale_fmt == "ue8m0":
inv_scales = torch.pow(2.0, torch.ceil(torch.log2(inv_scales.clamp(min=torch.finfo(torch.float32).tiny))))
inv_scales = inv_scales.to(_get_ue8m0_dtype())
Expand Down Expand Up @@ -1045,29 +1052,3 @@ def reverse_op(self) -> ConversionOps:
# checkpoint preserves the FP8 format (weight + per-block ``weight_scale_inv``)
# whether the in-memory state stayed quantized or was dequantized for compute.
return Fp8Quantize(self.hf_quantizer)


class Fp8DecodeScale(ConversionOps):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any ideas as to why this part was dropped ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because i added support for ue8m0 scales in finegrained-fp8 v3, this was needed for minimax m3 with the v2, but not anymore, it also wastes memory

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ue8m0 scales are a bit messy, some store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no special reason 😭 i'm trying to tighten the contract and honor the config all the times because supporting all the on-disk variations would be more complicated

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay ! Just to be sure, if we remove it now, it would not break existing checkpoints that are in mxpf8 format right ?

@IlyasMoutawwakil IlyasMoutawwakil Jun 23, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no they will work fine, even better because I just noticed that the fp32 scales are even avoiding the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64 I also ran minimax m3 integration tests on the b200

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Decode MXFP8 ``ue8m0`` per-block scales (stored as ``uint8`` exponents) into the
float32 multiplicative scales the FP8 compute path expects.

Native MXFP8 loading (``dequantize=False``) keeps weights in ``float8_e4m3fn`` and only
needs the sibling ``*.weight_scale_inv`` tensors turned from raw E8M0 bytes into real
scales (``2 ** (byte - 127)``). Prepended to each weight converter, this op runs before
any merge/concat collapses the per-expert structure: it rewrites only the ``uint8`` scale
entries and passes weights (and already-float scales) through untouched.
"""

def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

@staticmethod
def _decode(tensor: torch.Tensor) -> torch.Tensor:
# E8M0 stores one exponent byte per block; the real scale is ``2 ** (byte - 127)``.
return (tensor.to(torch.float32) - 127.0).exp2() if tensor.dtype == torch.uint8 else tensor

def convert(self, input_dict: dict[str, list[torch.Tensor] | torch.Tensor], **kwargs):
return {
key: [self._decode(t) for t in value] if isinstance(value, list) else self._decode(value)
for key, value in input_dict.items()
}
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,12 @@ def tp_plan(self) -> dict[str, str]:
The full tp plan for the model's modules
"""
if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
if not self._ep_plan:
raise ValueError(
f"Expert parallelism was requested (`enable_expert_parallel=True`), but "
f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a "
f"`base_model_ep_plan` to its config, or disable expert parallelism."
)
return self._ep_plan
return self._tp_plan

Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/afmoe/configuration_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class AfmoeConfig(PreTrainedConfig):
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_model_ep_plan = {
"layers.*.mlp.router": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}

vocab_size: int = 200192
hidden_size: int = 2048
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class Cohere2MoeConfig(PreTrainedConfig):
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}

vocab_size: int = 256000
hidden_size: int = 8192
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Cohere2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.expert_selection_fn = config.expert_selection_fn
self.norm_topk_prob = config.norm_topk_prob
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size))
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/cohere2_moe/modular_cohere2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Cohere2MoeTopKRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
self.expert_selection_fn = config.expert_selection_fn
self.norm_topk_prob = config.norm_topk_prob
self.weight = nn.Parameter(torch.empty(config.num_experts, config.hidden_size))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ class DeepseekOcr2TextConfig(PreTrainedConfig):
attention_dropout: float | None = 0.0
mlp_bias: bool = False
head_dim: int | None = None
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}
n_group: int | None = None
n_routed_experts: int = 64
n_shared_experts: int = 2
Expand Down
55 changes: 33 additions & 22 deletions src/transformers/models/deepseek_ocr2/modeling_deepseek_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1194,48 +1194,57 @@ def forward(
return final_hidden_states


class DeepseekOcr2TextMoe(nn.Module):
class DeepseekOcr2TextTopkRouter(nn.Module):
def __init__(self, config: DeepseekOcr2TextConfig):
super().__init__()
self.config = config
self.experts = DeepseekOcr2TextExperts(config)
self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size)
self.num_experts = config.n_routed_experts

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundancy in variables ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeh i guess we can drop n_routed_experts, removing it

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm so it seems to cascade into many models

self.top_k = config.num_experts_per_tok
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.num_group = config.n_group
self.top_k = config.num_experts_per_tok
self.topk_group = config.topk_group
# Named `weight` (not wrapped in `nn.Linear`) so the parameter key stays `mlp.gate.weight`.
self.weight = nn.Parameter(torch.empty((self.num_experts, config.hidden_size)))

def route_tokens_to_experts(self, router_logits):
batch_size, seq_len, hidden_dim = router_logits.shape
router_logits = router_logits.view(-1, hidden_dim)
router_logits = router_logits.softmax(dim=-1, dtype=torch.float32)
def forward(self, hidden_states):
# Top-k selection lives in the router (not the MoE block) so the `ep_router`
# tensor-parallel hook can remap global → local expert ids on the returned indices.
hidden_states = hidden_states.view(-1, self.config.hidden_size)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
scores = router_logits.softmax(dim=-1, dtype=torch.float32)
if self.topk_method == "greedy":
topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False)
topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
elif self.topk_method == "group_limited_greedy":
group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values
group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group)
.reshape(batch_size * seq_len, -1)
.expand(-1, self.num_group, self.num_experts // self.num_group)
.reshape(-1, self.num_experts)
)
tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0)
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
topk_weights = topk_weights * self.routed_scaling_factor
return router_logits, topk_weights, topk_indices

topk_weight = topk_weight * self.routed_scaling_factor
return topk_idx, topk_weight

class DeepseekOcr2TextMoe(nn.Module):
def __init__(self, config: DeepseekOcr2TextConfig):
super().__init__()
self.config = config
self.experts = DeepseekOcr2TextExperts(config)
self.gate = DeepseekOcr2TextTopkRouter(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
_, topk_weights, topk_indices = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
Expand Down Expand Up @@ -1330,7 +1339,9 @@ class DeepseekOcr2TextPreTrainedModel(PreTrainedModel):
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, DeepseekOcr2TextExperts):
if isinstance(module, DeepseekOcr2TextTopkRouter):
init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
elif isinstance(module, DeepseekOcr2TextExperts):
init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range)
init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class DeepseekV2Config(PreTrainedConfig):
attention_dropout: float | None = 0.0
mlp_bias: bool = False
head_dim: int | None = None
base_model_ep_plan = {
"layers.*.mlp.gate": "ep_router",
"layers.*.mlp.experts.gate_up_proj": "grouped_gemm",
"layers.*.mlp.experts.down_proj": "grouped_gemm",
"layers.*.mlp.experts": "moe_tp_experts",
}
first_k_dense_replace: int = 0
kv_lora_rank: int = 512
q_lora_rank: int | None = 1536
Expand Down
Loading
Loading