🚨 EP: fix EP router contract for many models + honor FP8 scale format #46818
Changes from 11 commits
777c272
122adb5
98034c9
f1e2235
83625bc
b34872e
0e0550a
4a585fe
0288a11
7d5976d
8ea20bf
d288a9e
d7a2fea
99f28f1
b8f8eed
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1194,48 +1194,57 @@ def forward( | |
| return final_hidden_states | ||
|
|
||
|
|
||
| class DeepseekOcr2TextMoe(nn.Module): | ||
| class DeepseekOcr2TextTopkRouter(nn.Module): | ||
| def __init__(self, config: DeepseekOcr2TextConfig): | ||
| super().__init__() | ||
| self.config = config | ||
| self.experts = DeepseekOcr2TextExperts(config) | ||
| self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) | ||
| if config.n_shared_experts is not None: | ||
| intermediate_size = config.moe_intermediate_size * config.n_shared_experts | ||
| self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size) | ||
| self.num_experts = config.n_routed_experts | ||
|
Member
Redundancy in variables?
Member
Author
Yeah, I guess we can drop n_routed_experts; removing it.
Member
Author
Hmm, so it seems to cascade into many models. |
||
| self.top_k = config.num_experts_per_tok | ||
| self.routed_scaling_factor = config.routed_scaling_factor | ||
| self.topk_method = config.topk_method | ||
| self.num_group = config.n_group | ||
| self.top_k = config.num_experts_per_tok | ||
| self.topk_group = config.topk_group | ||
| # Named `weight` (not wrapped in `nn.Linear`) so the parameter key stays `mlp.gate.weight`. | ||
| self.weight = nn.Parameter(torch.empty((self.num_experts, config.hidden_size))) | ||
|
|
||
| def route_tokens_to_experts(self, router_logits): | ||
| batch_size, seq_len, hidden_dim = router_logits.shape | ||
| router_logits = router_logits.view(-1, hidden_dim) | ||
| router_logits = router_logits.softmax(dim=-1, dtype=torch.float32) | ||
| def forward(self, hidden_states): | ||
| # Top-k selection lives in the router (not the MoE block) so the `ep_router` | ||
| # tensor-parallel hook can remap global → local expert ids on the returned indices. | ||
| hidden_states = hidden_states.view(-1, self.config.hidden_size) | ||
| router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) | ||
| scores = router_logits.softmax(dim=-1, dtype=torch.float32) | ||
| if self.topk_method == "greedy": | ||
| topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False) | ||
| topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) | ||
| elif self.topk_method == "group_limited_greedy": | ||
| group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values | ||
| group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values | ||
| group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] | ||
| group_mask = torch.zeros_like(group_scores) | ||
| group_mask.scatter_(1, group_idx, 1) | ||
| score_mask = ( | ||
| group_mask.unsqueeze(-1) | ||
| .expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group) | ||
| .reshape(batch_size * seq_len, -1) | ||
| .expand(-1, self.num_group, self.num_experts // self.num_group) | ||
| .reshape(-1, self.num_experts) | ||
| ) | ||
| tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0) | ||
| topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) | ||
| scores = scores.masked_fill(~score_mask.bool(), 0.0) | ||
| topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) | ||
| topk_weights = topk_weights * self.routed_scaling_factor | ||
| return router_logits, topk_weights, topk_indices | ||
|
|
||
| topk_weight = topk_weight * self.routed_scaling_factor | ||
| return topk_idx, topk_weight | ||
|
|
||
| class DeepseekOcr2TextMoe(nn.Module): | ||
| def __init__(self, config: DeepseekOcr2TextConfig): | ||
| super().__init__() | ||
| self.config = config | ||
| self.experts = DeepseekOcr2TextExperts(config) | ||
| self.gate = DeepseekOcr2TextTopkRouter(config) | ||
| if config.n_shared_experts is not None: | ||
| intermediate_size = config.moe_intermediate_size * config.n_shared_experts | ||
| self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size) | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| residuals = hidden_states | ||
| orig_shape = hidden_states.shape | ||
| router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) | ||
| topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) | ||
| _, topk_weights, topk_indices = self.gate(hidden_states) | ||
| hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
| hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) | ||
| hidden_states = hidden_states + self.shared_experts(residuals) | ||
|
|
@@ -1330,7 +1339,9 @@ class DeepseekOcr2TextPreTrainedModel(PreTrainedModel): | |
| @torch.no_grad() | ||
| def _init_weights(self, module): | ||
| super()._init_weights(module) | ||
| if isinstance(module, DeepseekOcr2TextExperts): | ||
| if isinstance(module, DeepseekOcr2TextTopkRouter): | ||
| init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) | ||
| elif isinstance(module, DeepseekOcr2TextExperts): | ||
| init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) | ||
| init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) | ||
|
|
||
|
|
||
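
The routing in the diff above can be followed without PyTorch. Below is a minimal pure-Python sketch of the `group_limited_greedy` branch for a single token, followed by a global-to-local expert id remap of the kind the `ep_router` comment describes. The helper names (`group_limited_topk`, `to_local_expert_ids`) are hypothetical, and the contiguous "rank r owns experts r*E to (r+1)*E-1" layout is an assumption made for illustration, not something the PR states. Unlike `torch.topk(..., sorted=False)`, this sketch returns indices in descending score order.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def group_limited_topk(logits, num_group, topk_group, top_k, scaling=1.0):
    """Route one token: keep the best `topk_group` groups, then take `top_k` experts from them."""
    scores = softmax(logits)
    num_experts = len(scores)
    group_size = num_experts // num_group
    # A group's score is the best expert score inside it.
    group_scores = [max(scores[g * group_size:(g + 1) * group_size]) for g in range(num_group)]
    kept = set(sorted(range(num_group), key=group_scores.__getitem__, reverse=True)[:topk_group])
    # Experts in dropped groups get score 0.0, mirroring the masked_fill in the diff.
    masked = [s if i // group_size in kept else 0.0 for i, s in enumerate(scores)]
    topk_indices = sorted(range(num_experts), key=masked.__getitem__, reverse=True)[:top_k]
    topk_weights = [masked[i] * scaling for i in topk_indices]
    return topk_weights, topk_indices


def to_local_expert_ids(topk_weights, topk_indices, rank, experts_per_rank):
    """Assumed layout: each rank owns one contiguous block of experts (hypothetical)."""
    first = rank * experts_per_rank
    return [
        (i - first, w)
        for w, i in zip(topk_weights, topk_indices)
        if first <= i < first + experts_per_rank
    ]


# 8 experts in 4 groups of 2. Expert 2 has the third-highest logit,
# but its group loses to groups 0 and 2, so it is never selected.
logits = [1.0, 5.0, 4.0, 0.0, 4.5, 0.5, 0.0, 0.0]
weights, indices = group_limited_topk(logits, num_group=4, topk_group=2, top_k=3)
print(indices)  # [1, 4, 0]
```

Because the router returns the indices itself, a hook wrapped around it can rewrite them per rank before the experts module ever sees them, which is the motivation given in the code comment for moving top-k selection out of the MoE block.
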
Any ideas as to why this part was dropped?
Because I added support for ue8m0 scales in finegrained-fp8 v3. This was needed for minimax m3 with v2, but not anymore, and it also wastes memory.
ue8m0 scales are a bit messy: some checkpoints store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no special reason 😭 I'm trying to tighten the contract and honor the config at all times, because supporting all the on-disk variations would be more complicated.
Okay! Just to be sure: if we remove it now, it would not break existing checkpoints that are in mxfp8 format, right?
No, they will work fine, even better in fact: I just noticed that the fp32 scales were avoiding the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64. I also ran the minimax m3 integration tests on the B200.
@ArthurZucker