From 41e307830a7a2ab8fa6f3133245fb43642033bec Mon Sep 17 00:00:00 2001 From: Mario Michael Krell <172859788+mario-aws@users.noreply.github.com> Date: Sun, 30 Mar 2025 20:45:36 -0700 Subject: [PATCH] Correctly drop tokens in SwitchTransformer Previously, the identity function was used for dropped tokens with a weight from the expert that was not applied to the hidden states. This was misleading, because dropping means, the expert weight is zero. Instead of trying to fix the weight, we take an easier approach by initializing with zeros. Fixes issue https://github.com/huggingface/transformers/issues/37017 --- .../modeling_switch_transformers.py | 6 ++---- .../test_modeling_switch_transformers.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d2d9929b9128..d7a158e9e8d4 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -301,10 +301,8 @@ def forward(self, hidden_states): router_mask, router_probs, router_logits = self.router(hidden_states) expert_index = torch.argmax(router_mask, dim=-1) - # The routers introduced might not always map all the tokens, to a router, which means that some hidden states - # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones. - - next_states = hidden_states.clone() + # If a token gets dropped, we just set it to zero such that it does not get updated. + next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype) router_mask = router_mask.bool() batch_size, seq_len, num_experts = router_mask.shape diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 622c579843c8..967030b61ff7 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -43,6 +43,7 @@ SwitchTransformersEncoderModel, SwitchTransformersForConditionalGeneration, SwitchTransformersModel, + SwitchTransformersSparseMLP, SwitchTransformersTop1Router, ) from transformers.models.switch_transformers.modeling_switch_transformers import ( @@ -1134,3 +1135,16 @@ def test_small_batch_generate(self): for i in range(0, BATCH_SIZE, 2): self.assertEqual(batch_output[i], batch_output[i + 1]) + + +@require_torch +class SwitchTransformersSparseMLPTests(unittest.TestCase): + def test_token_dropping(self): + r""" + This test checks if the token dropping actually drops tokens. + """ + config = SwitchTransformersConfig(expert_capacity=0) # we drop everything + moe = SwitchTransformersSparseMLP(config) + dropped_token_results = moe(torch.randn(2, 3, 768))[0] + + assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."