SwitchTransformer: Initialization of tensor to collect expert results is incorrect for dropped tokens (from ML POV) #37017

mario-aws opened this issue Mar 26, 2025 · 3 comments · May be fixed by #37123

@mario-aws

System Info

This is about a logical bug from an ML point of view. It will not result in crashes, but it can influence model behavior significantly.

In the transformers code for SwitchTransformer, the tensor that collects the expert results of the sparse MLP is initialized with the hidden states; it is then updated in place per expert via index assignment and finally scaled by the router probabilities:

next_states = hidden_states.clone()
...
for idx in idx_mask:
    next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
        hidden_states[router_mask[:, :, idx]]
    )
hidden_states = router_probs * next_states

While this logic is fine for all tokens that are not dropped, it is wrong for dropped tokens. Setting expert_capacity to zero as an extreme test where all tokens get dropped, the output would be the original hidden states scaled by the probability of the expert they were never actually assigned to. Note that router_probs is not set to zero for dropped tokens. Also note that the residual connection happens at a different stage and is not related to this part of the code. Why is this wrong from an ML point of view? (A short sketch after the list below illustrates the difference.)

  1. Dropping means that the token is not updated. Initializing with the hidden state effectively provides an "identity" expert update.
  2. The weight should correspond to the weight of the expert that was actually applied. Since no expert update is executed for a dropped token, its weight should be zero, not the maximum routing weight for that token.
  3. If an expert drops many tokens, it partially behaves normally and partially behaves like the identity function, which can be very different. From an ML point of view, an expert should have only one behavior.
  4. The scaling of the results can differ considerably between an expert and the identity function.
  5. While this error does not affect the expert weights themselves, it influences the router: the quality of dropped tokens is probably degraded and the expert receives a reduced weight, which could result in unexpected load balancing.
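
To make this concrete, here is a minimal sketch (plain PyTorch, hypothetical toy shapes, no transformers dependency) of what the two initializations do to tokens that no expert processes:

import torch

# Toy setup: no token is routed to any expert, yet router_probs stays non-zero,
# mirroring the situation described above for dropped tokens.
hidden_states = torch.randn(1, 2, 4)                   # (batch, seq_len, d_model)
router_mask = torch.zeros(1, 2, 2, dtype=torch.bool)   # nothing routed, expert loop is skipped
router_probs = torch.rand(1, 2, 1)                     # max routing probability per token

# Current initialization: dropped tokens pass through as a probability-scaled identity.
current_output = router_probs * hidden_states.clone()  # non-zero although no expert ran

# Proposed initialization: dropped tokens contribute nothing.
fixed_output = router_probs * torch.zeros_like(hidden_states)  # exactly zero

print(current_output.abs().sum() > 0)  # tensor(True)
print(fixed_output.abs().sum() == 0)   # tensor(True)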

This has probably not been noticed so far because usually only very few tokens are intended to be dropped, so changing this behavior will probably not have much impact in the grand scheme of things. A fix could look like:

next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
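
An equivalent, slightly more compact way to write the same initialization would be:

next_states = torch.zeros_like(hidden_states)  # same shape, dtype, and device as hidden_states

Either way, dropped tokens contribute nothing at this point and keep their original value only through the residual connection applied outside this module.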

transformers-cli env output (probably not relevant)

  • transformers version: 4.46.2
  • Platform: Linux-6.2.0-1018-aws-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.6.0+cpu (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I do not have a nice example yet, but it should look something like:

import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersSparseMLP

config = SwitchTransformersConfig()
config.expert_capacity = 0  # force every token to be dropped
model = SwitchTransformersSparseMLP(config)
seed = 42
generator = torch.Generator().manual_seed(seed)
shape = (1, 4096, config.d_model)  # (batch, seq_len, d_model); d_model defaults to 768
data = torch.randn(shape, generator=generator, dtype=torch.float32)
assert (model(data)[0] == 0).all(), "All tokens need to be properly dropped."

Expected behavior

The result of the module should be all zeros if all tokens are dropped, not some arbitrary scaling of the data.

assert (model(data)[0] == 0).all(), "All tokens need to be properly dropped."

would be the corresponding assertion.

@mario-aws mario-aws added the bug label Mar 26, 2025
@Rocketknight1
Member

Rocketknight1 commented Mar 27, 2025

Hi @mario-aws, this seems legit - but unfortunately (for you) I think you might be the most qualified person to understand and fix this bug at this point! Would you be willing to make the fix in a PR, and maybe add an example/test to show the difference?

@mario-aws
Author

mario-aws commented Mar 28, 2025

Thanks for the quick feedback. I will familiarize myself more with the HF code base and PR procedures and provide a PR. In the meantime, let me share the simple reproducer test that I have in mind. The fix is a one-liner.

import torch
from torch import nn

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersDenseActDense,
    SwitchTransformersTop1Router,
)

USE_FIX = False


class SwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.

    Copied from SwitchTransformersSparseMLP in switch_transformers to fix the `next_states` initialization.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, an MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden state values (they can be interpreted as a scaling factor).

        2- Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign
        the corresponding hidden states to each expert.

        """
        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        if USE_FIX:
            # If a token gets dropped, we just set it to zero such that it does not get updated.
            next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
        else:
            # The router might not always map all the tokens to an expert, which means that some hidden states
            # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
            next_states = hidden_states.clone()

        router_mask = router_mask.bool()
        batch_size, seq_len, num_experts = router_mask.shape
        idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # length: number of "activated" experts / value: expert index
        for idx in idx_mask:
            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
                hidden_states[router_mask[:, :, idx]]
            )

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)


config = SwitchTransformersConfig(expert_capacity=0)
moe = SwitchTransformersSparseMLP(config)
dropped_token_results = moe(torch.randn(2, 3, 768))[0]

assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
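
Run as-is (USE_FIX = False), the script reproduces the bug: the assertion fails because dropped tokens come back as router_probs * hidden_states instead of zeros. A quick toggle on top of the same script (hypothetical variable names) shows the fixed behavior:

# With the zero initialization, dropped tokens contribute nothing and the check passes.
USE_FIX = True
moe_fixed = SwitchTransformersSparseMLP(config)
fixed_results = moe_fixed(torch.randn(2, 3, 768))[0]
assert (fixed_results == 0).all(), "Dropped tokens should be zeroed."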

mario-aws added a commit to mario-aws/transformers_switch_fix that referenced this issue Mar 31, 2025
Previously, the identity function was used for dropped tokens,
scaled by the weight of an expert that was never applied to the hidden states.
This was misleading, because dropping means the expert weight should be zero.
Instead of trying to fix the weight, we take the simpler approach of initializing with zeros.

Fixes issue huggingface#37017
@mario-aws mario-aws linked a pull request Mar 31, 2025 that will close this issue
@mario-aws
Author

I created a respective PR: #37123
