
SwitchTransformer: Initialization of tensor to collect expert results is incorrect for dropped tokens (from ML POV) #37017

@mario-aws

Description


System Info

This is about a logical bug from an ML point of view. It will not cause crashes, but it significantly influences model behavior.

In the transformers code of SwitchTransformers, we initialize the tensor for collecting expert results of the MLP with the hidden states, then update it via index assignments and finally scale by the router probabilities.

next_states = hidden_states.clone()
...
for idx in idx_mask:
    next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
        hidden_states[router_mask[:, :, idx]]
    )
hidden_states = router_probs * next_states

While this logic is fine for all tokens that are not dropped, it is wrong for dropped tokens. Setting expert_capacity to zero as an extreme test in which every token is dropped, the output is the original hidden states scaled by the probability of the expert they were never actually assigned to. Note that router_probs is not set to zero for dropped tokens. Also note that the residual connection happens at a different stage and is not related to this part of the code. Why is this wrong from an ML POV?

  1. Dropping means that the tokens are not updated. Initializing with the hidden states instead provides an "identity" expert update.
  2. The weight should correspond to the respective weight of the expert. Since the update is never executed for this token, the weight should be set to zero, not to the max weight for this token.
  3. If an expert drops many tokens, it partially behaves normally and partially behaves like the identity function, which can be very different. From an ML point of view, an expert should have only one behavior.
  4. The scaling of results can be quite different between an expert and an identity function.
  5. While this error does not affect the expert weights, it does influence the router: the quality of dropped tokens is probably degraded, and the expert gets a reduced weight. This could result in unexpected load balancing.
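Points 1 and 2 can be illustrated with a minimal pure-Python sketch for a single dropped token (the hidden state and router probability below are hypothetical numbers, not taken from transformers):

```python
# Numeric sketch of the bug for one dropped token.
# All names and values are illustrative only.

hidden_state = [0.5, -1.0, 2.0]  # token representation entering the MLP
router_prob = 0.9                # probability of the expert the token was routed to

# Current behavior: next_states starts as a copy of hidden_states, and a
# dropped token is never overwritten, so it ends up scaled by router_prob.
buggy_output = [router_prob * h for h in hidden_state]

# Expected behavior: a dropped token contributes no expert update, so the
# MLP output for it should be zero (the residual connection elsewhere
# still passes the token through).
fixed_output = [0.0, 0.0, 0.0]

print(buggy_output)  # approximately [0.45, -0.9, 1.8]
print(fixed_output)
```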

This has probably not been noticed so far, because usually only very few tokens are intended to be dropped. So changing this behavior will probably not have much impact in the grand scheme of things. A fix could look like:

next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
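With zero initialization, the collection loop behaves as intended: dropped tokens stay at zero and are unaffected by the router-probability scaling. A pure-Python sketch of the fixed loop (illustrative only; the real code operates on torch tensors and transformers expert modules):

```python
# Sketch of the fixed collection loop. "assignments[i]" is the expert
# index for token i, or None if the token was dropped (hypothetical API).

def sparse_mlp(hidden_states, router_probs, assignments, experts):
    # Fix: start from zeros instead of a copy of hidden_states, so that
    # dropped tokens contribute no "identity" update.
    next_states = [[0.0] * len(h) for h in hidden_states]
    for i, expert_idx in enumerate(assignments):
        if expert_idx is not None:  # token was not dropped
            next_states[i] = experts[expert_idx](hidden_states[i])
    # Scale by router probability; dropped tokens stay all-zero.
    return [[p * x for x in row] for p, row in zip(router_probs, next_states)]

# Hypothetical example: two tokens, token 1 is dropped.
experts = {0: lambda h: [2.0 * x for x in h]}  # toy "expert": doubles inputs
out = sparse_mlp(
    hidden_states=[[1.0, -1.0], [3.0, 4.0]],
    router_probs=[0.5, 0.9],
    assignments=[0, None],
    experts=experts,
)
print(out)  # token 0: [1.0, -1.0]; dropped token 1: [0.0, 0.0]
```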

transformers-cli env output (probably not relevant)

  • transformers version: 4.46.2
  • Platform: Linux-6.2.0-1018-aws-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.29.3
  • Safetensors version: 0.5.3
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.6.0+cpu (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I do not have a nice example yet, but it should look something like:

import torch
from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersSparseMLP,
)

config = SwitchTransformersConfig()
config.expert_capacity = 0
model = SwitchTransformersSparseMLP(config)
shape = (1, 4096, config.d_model)  # SparseMLP expects (batch, seq_len, d_model)
seed = 42
generator = torch.Generator().manual_seed(seed)
data = torch.randn(shape, generator=generator, dtype=torch.float32)
assert (model(data)[0] == 0).all(), "All tokens need to be properly dropped."

Expected behavior

If all tokens are dropped, the result of the module should be all zeros, not some arbitrary scaling of the data.

assert (model(data)[0] == 0).all(), "All tokens need to be properly dropped."

would be the corresponding assertion.
