Commit bde41d6

Correctly drop tokens in SwitchTransformer (#37123)
Previously, dropped tokens were passed through the identity function and still scaled by a router weight for an expert that was never applied to them. This was misleading: dropping a token means its expert weight is zero. Instead of trying to fix the weight, we take the easier approach of initializing `next_states` with zeros. Fixes issue #37017
1 parent 7ecc5b8 · commit bde41d6
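In effect, the change swaps the initialization of `next_states` from a clone of the input to zeros. A minimal, self-contained sketch of the difference in semantics (the shapes, the routing mask, and the single Linear stand-in expert are illustrative assumptions, not the actual module code):

import torch

# Illustrative shapes and routing decisions (assumptions, not the real config).
batch, seq, d_model = 2, 3, 8
hidden_states = torch.randn(batch, seq, d_model)
routed = torch.tensor([[True, False, True],
                       [False, True, True]])  # False = token dropped by the router
expert = torch.nn.Linear(d_model, d_model)    # stand-in for one expert MLP

# Before the fix: a dropped token silently keeps its input value, as if an
# identity expert had been applied to it.
next_states = hidden_states.clone()
next_states[routed] = expert(hidden_states[routed])

# After the fix: a dropped token stays at zero, matching "expert weight is zero".
next_states = torch.zeros_like(hidden_states)
next_states[routed] = expert(hidden_states[routed])
assert (next_states[~routed] == 0).all()

With zeros, a dropped token contributes nothing to the expert output; it is the parent layer's residual connection, not a phantom identity expert, that carries it through unchanged.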

File tree: 2 files changed (+16, -4 lines)

src/transformers/models/switch_transformers/modeling_switch_transformers.py

Lines changed: 2 additions & 4 deletions

@@ -301,10 +301,8 @@ def forward(self, hidden_states):
         router_mask, router_probs, router_logits = self.router(hidden_states)
         expert_index = torch.argmax(router_mask, dim=-1)
 
-        # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
-        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
-
-        next_states = hidden_states.clone()
+        # If a token gets dropped, we just set it to zero such that it does not get updated.
+        next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
 
         router_mask = router_mask.bool()
         batch_size, seq_len, num_experts = router_mask.shape
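For context, the zero-initialized `next_states` above is then filled expert-by-expert and scaled by the router probabilities. A self-contained sketch of that update pattern (the two-expert setup, the hand-built mask, and the probability handling are assumptions for illustration, not the module's exact code):

import torch
import torch.nn as nn

batch, seq, d_model, num_experts = 2, 4, 8, 2
hidden_states = torch.randn(batch, seq, d_model)
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))

# One-hot routing mask; an all-False row models a dropped token.
router_mask = torch.zeros(batch, seq, num_experts, dtype=torch.bool)
router_mask[0, 0, 0] = True   # token (0, 0) -> expert 0
router_mask[0, 1, 1] = True   # token (0, 1) -> expert 1; every other token is dropped
router_probs = router_mask.float().max(dim=-1, keepdim=True).values  # stand-in probabilities

# The fixed initialization: dropped tokens start, and stay, at zero.
next_states = torch.zeros_like(hidden_states)
for idx, expert in enumerate(experts):
    token_mask = router_mask[:, :, idx]       # tokens routed to expert idx
    if token_mask.any():
        next_states[token_mask] = expert(hidden_states[token_mask])

out = router_probs * next_states
assert (out[~router_mask.any(dim=-1)] == 0).all()  # dropped tokens are exactly zero

A dropped token is selected by no expert's mask, so it stays at zero through both the loop and the scaling, which is exactly what the new test below asserts.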

tests/models/switch_transformers/test_modeling_switch_transformers.py

Lines changed: 14 additions & 0 deletions

@@ -42,6 +42,7 @@
     SwitchTransformersEncoderModel,
     SwitchTransformersForConditionalGeneration,
     SwitchTransformersModel,
+    SwitchTransformersSparseMLP,
     SwitchTransformersTop1Router,
 )
 from transformers.models.switch_transformers.modeling_switch_transformers import (
@@ -1133,3 +1134,16 @@ def test_small_batch_generate(self):
 
         for i in range(0, BATCH_SIZE, 2):
             self.assertEqual(batch_output[i], batch_output[i + 1])
+
+
+@require_torch
+class SwitchTransformersSparseMLPTests(unittest.TestCase):
+    def test_token_dropping(self):
+        r"""
+        This test checks if the token dropping actually drops tokens.
+        """
+        config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
+        moe = SwitchTransformersSparseMLP(config)
+        dropped_token_results = moe(torch.randn(2, 3, 768))[0]
+
+        assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."
