🚨 EP: fix EP router contract for many models + honor FP8 scale format #46818
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1458,6 +1458,12 @@ def tp_plan(self) -> dict[str, str]: | |
| The full tp plan for the model's modules | ||
| """ | ||
| if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel: | ||
| if not self._ep_plan: | ||
| raise ValueError( | ||
| f"Expert parallelism was requested (`enable_expert_parallel=True`), but " | ||
| f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a " | ||
| f"`base_model_ep_plan` to its config, or disable expert parallelism." | ||
| ) | ||
|
Comment on lines +1470 to +1475 (Member, Author): Loud failure on a missing EP plan.
| return self._ep_plan | ||
| return self._tp_plan | ||
|
|
||
|
|
||
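The fail-fast behaviour in the hunk above can be sketched in isolation. This is a minimal illustration, not the library's code: `DistributedConfig`, `Config`, and `ToyModel` are hypothetical stand-ins, and the `None` guard on `distributed_config` is an assumption added for the sketch (the diff itself only uses `hasattr`).

```python
from dataclasses import dataclass, field


@dataclass
class DistributedConfig:
    # Hypothetical stand-in for the real distributed config object.
    enable_expert_parallel: bool = False


@dataclass
class Config:
    # Hypothetical model config carrying only the field this sketch needs.
    distributed_config: DistributedConfig = field(default_factory=DistributedConfig)


class ToyModel:
    """Hypothetical model; only the plan lookup mirrors the diff."""

    def __init__(self, config, tp_plan=None, ep_plan=None):
        self.config = config
        self._tp_plan = tp_plan or {}
        self._ep_plan = ep_plan or {}

    @property
    def tp_plan(self):
        # Assumption: treat a missing or None distributed_config as "EP disabled".
        dist = getattr(self.config, "distributed_config", None)
        if dist is not None and dist.enable_expert_parallel:
            if not self._ep_plan:
                # Fail loudly instead of silently returning an empty plan.
                raise ValueError(
                    "Expert parallelism was requested (`enable_expert_parallel=True`), but "
                    f"`{self.__class__.__name__}` does not define an expert-parallel plan."
                )
            return self._ep_plan
        return self._tp_plan


if __name__ == "__main__":
    model = ToyModel(Config(DistributedConfig(enable_expert_parallel=True)))
    try:
        model.tp_plan
    except ValueError as err:
        print(f"raised as expected: {err}")
```

Raising at plan lookup surfaces the misconfiguration when the model is set up, rather than later as a silent fallback to an empty plan.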
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1160,7 +1160,7 @@ class DeepseekOcr2TextExperts(nn.Module): | |
|
|
||
| def __init__(self, config): | ||
| super().__init__() | ||
| self.num_experts = config.n_routed_experts | ||
| self.num_experts = config.num_experts | ||
| self.hidden_dim = config.hidden_size | ||
| self.intermediate_dim = config.moe_intermediate_size | ||
| self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) | ||
|
|
@@ -1194,48 +1194,54 @@ def forward( | |
| return final_hidden_states | ||
|
|
||
|
|
||
| class DeepseekOcr2TextMoe(nn.Module): | ||
| class DeepseekOcr2TextTopkRouter(nn.Module): | ||
| def __init__(self, config: DeepseekOcr2TextConfig): | ||
| super().__init__() | ||
| self.config = config | ||
| self.experts = DeepseekOcr2TextExperts(config) | ||
| self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False) | ||
| if config.n_shared_experts is not None: | ||
| intermediate_size = config.moe_intermediate_size * config.n_shared_experts | ||
| self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size) | ||
| self.num_experts = config.n_routed_experts | ||
|
Member: Redundancy in variables?
Member (Author): Yeah, I guess we can drop `n_routed_experts`; removing it.
Member (Author): Hmm, it seems this cascades into many models.
||
| self.top_k = config.num_experts_per_tok | ||
| self.routed_scaling_factor = config.routed_scaling_factor | ||
| self.topk_method = config.topk_method | ||
| self.num_group = config.n_group | ||
| self.top_k = config.num_experts_per_tok | ||
| self.topk_group = config.topk_group | ||
| self.weight = nn.Parameter(torch.empty((self.num_experts, config.hidden_size))) | ||
|
|
||
| def route_tokens_to_experts(self, router_logits): | ||
| batch_size, seq_len, hidden_dim = router_logits.shape | ||
| router_logits = router_logits.view(-1, hidden_dim) | ||
| router_logits = router_logits.softmax(dim=-1, dtype=torch.float32) | ||
| def forward(self, hidden_states): | ||
| hidden_states = hidden_states.view(-1, self.config.hidden_size) | ||
| router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) | ||
| scores = router_logits.softmax(dim=-1, dtype=torch.float32) | ||
| if self.topk_method == "greedy": | ||
| topk_weight, topk_idx = torch.topk(router_logits, k=self.top_k, dim=-1, sorted=False) | ||
| topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) | ||
| elif self.topk_method == "group_limited_greedy": | ||
| group_scores = router_logits.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values | ||
| group_scores = scores.view(-1, self.num_group, self.num_experts // self.num_group).max(dim=-1).values | ||
| group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] | ||
| group_mask = torch.zeros_like(group_scores) | ||
| group_mask.scatter_(1, group_idx, 1) | ||
| score_mask = ( | ||
| group_mask.unsqueeze(-1) | ||
| .expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group) | ||
| .reshape(batch_size * seq_len, -1) | ||
| .expand(-1, self.num_group, self.num_experts // self.num_group) | ||
| .reshape(-1, self.num_experts) | ||
| ) | ||
| tmp_scores = router_logits.masked_fill(~score_mask.bool(), 0.0) | ||
| topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) | ||
| scores = scores.masked_fill(~score_mask.bool(), 0.0) | ||
| topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False) | ||
| topk_weights = topk_weights * self.routed_scaling_factor | ||
| return router_logits, topk_weights, topk_indices | ||
|
|
||
| topk_weight = topk_weight * self.routed_scaling_factor | ||
| return topk_idx, topk_weight | ||
|
|
||
| class DeepseekOcr2TextMoe(nn.Module): | ||
| def __init__(self, config: DeepseekOcr2TextConfig): | ||
| super().__init__() | ||
| self.config = config | ||
| self.experts = DeepseekOcr2TextExperts(config) | ||
| self.gate = DeepseekOcr2TextTopkRouter(config) | ||
| self.shared_experts = DeepseekOcr2TextMLP( | ||
| config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts | ||
| ) | ||
|
|
||
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||
| residuals = hidden_states | ||
| orig_shape = hidden_states.shape | ||
| router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32)) | ||
| topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) | ||
| _, topk_weights, topk_indices = self.gate(hidden_states) | ||
| hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) | ||
| hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) | ||
| hidden_states = hidden_states + self.shared_experts(residuals) | ||
|
|
@@ -1330,7 +1336,9 @@ class DeepseekOcr2TextPreTrainedModel(PreTrainedModel): | |
| @torch.no_grad() | ||
| def _init_weights(self, module): | ||
| super()._init_weights(module) | ||
| if isinstance(module, DeepseekOcr2TextExperts): | ||
| if isinstance(module, DeepseekOcr2TextTopkRouter): | ||
| init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) | ||
| elif isinstance(module, DeepseekOcr2TextExperts): | ||
| init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) | ||
| init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) | ||
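The router contract in the diff above, where the gate returns `(router_logits, topk_weights, topk_indices)`, can be sketched without PyTorch. This is a dependency-free illustration under stated assumptions: the `softmax` and `route` helpers are hypothetical, tokens are plain lists of floats, and the selected experts come back sorted by score, whereas the diff calls `torch.topk` with `sorted=False`.

```python
import math


def softmax(row):
    """Numerically stable softmax over one token's expert logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def route(logits, top_k, num_group=1, topk_group=1, method="greedy", scaling=1.0):
    """Return (router_logits, topk_weights, topk_indices) for a batch of tokens."""
    num_experts = len(logits[0])
    per_group = num_experts // num_group
    all_weights, all_indices = [], []
    for row in logits:
        scores = softmax(row)
        if method == "group_limited_greedy":
            # Score each group by its best expert and keep the top `topk_group` groups.
            group_scores = [
                max(scores[g * per_group:(g + 1) * per_group]) for g in range(num_group)
            ]
            ranked = sorted(range(num_group), key=lambda g: group_scores[g], reverse=True)
            kept = set(ranked[:topk_group])
            # Zero the scores of experts that live in a dropped group.
            scores = [s if i // per_group in kept else 0.0 for i, s in enumerate(scores)]
        indices = sorted(range(num_experts), key=lambda i: scores[i], reverse=True)[:top_k]
        all_indices.append(indices)
        all_weights.append([scores[i] * scaling for i in indices])
    return logits, all_weights, all_indices


if __name__ == "__main__":
    token_logits = [[3.0, 0.0, 5.0, 1.0]]
    _, _, greedy = route(token_logits, top_k=2)
    _, _, grouped = route(
        token_logits, top_k=2, num_group=2, topk_group=1, method="group_limited_greedy"
    )
    print("greedy:", greedy, "group-limited:", grouped)
```

With two groups of two experts and `topk_group=1`, the group-limited variant can only pick experts from the winning group, so it may select a lower-scoring expert than plain greedy routing would.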
Any idea why this part was dropped?
Because I added support for ue8m0 scales in finegrained-fp8 v3. This was needed for minimax m3 with v2, but not anymore; it also wastes memory.
ue8m0 scales are a bit messy: some checkpoints store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no apparent reason 😭 I'm trying to tighten the contract and always honor the config, because supporting every on-disk variation would be more complicated.
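For readers unfamiliar with the format, here is a minimal sketch of what a ue8m0 scale encodes. It assumes the OCP Microscaling E8M0 convention (8 exponent bits, no sign or mantissa, bias 127, `0xFF` reserved for NaN). The helper names are hypothetical, and this is not how the finegrained-fp8 kernels read scales.

```python
import math

E8M0_BIAS = 127   # assumed exponent bias (OCP Microscaling convention)
E8M0_NAN = 0xFF   # assumed NaN encoding; the format has no zero or infinity


def e8m0_decode(byte):
    """Decode one E8M0 byte into the power-of-two scale it represents."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"not a byte: {byte}")
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)  # 2 ** (byte - 127)


def e8m0_encode(scale):
    """Encode an exact power-of-two scale as an E8M0 byte."""
    if not scale > 0 or math.isinf(scale):
        raise ValueError(f"scale must be a positive finite number: {scale}")
    mantissa, exponent = math.frexp(scale)  # scale == mantissa * 2 ** exponent
    if mantissa != 0.5:
        raise ValueError(f"scale is not a power of two: {scale}")
    byte = exponent - 1 + E8M0_BIAS
    if not 0 <= byte < E8M0_NAN:
        raise ValueError(f"scale out of E8M0 range: {scale}")
    return byte


if __name__ == "__main__":
    for scale in (1.0, 0.25, 8.0):
        byte = e8m0_encode(scale)
        print(f"scale={scale} -> byte={byte} -> decoded={e8m0_decode(byte)}")
```

Under this convention a scale needs only one byte, so storing the same values as fp32 spends four bytes on information that fits in one.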
Okay! Just to be sure: if we remove it now, it would not break existing checkpoints that are in mxfp8 format, right?
No, they will work fine, even better in fact, because I just noticed that the fp32 scales actually bypass the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64. I also ran the minimax m3 integration tests on the B200.
@ArthurZucker
Yep, sounds good. Just require the kernel version for that path, so that it errors out properly if that kernel version is not installed.
We do pin v3 in our lazy loading.