Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly drop tokens in SwitchTransformer #37123

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,8 @@ def forward(self, hidden_states):
router_mask, router_probs, router_logits = self.router(hidden_states)
expert_index = torch.argmax(router_mask, dim=-1)

# The routers introduced might not always map all the tokens, to a router, which means that some hidden states
# can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

next_states = hidden_states.clone()
# If a token gets dropped, we just set it to zero such that it does not get updated.
next_states = torch.zeros(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)

router_mask = router_mask.bool()
batch_size, seq_len, num_experts = router_mask.shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
SwitchTransformersEncoderModel,
SwitchTransformersForConditionalGeneration,
SwitchTransformersModel,
SwitchTransformersSparseMLP,
SwitchTransformersTop1Router,
)
from transformers.models.switch_transformers.modeling_switch_transformers import (
Expand Down Expand Up @@ -1134,3 +1135,16 @@ def test_small_batch_generate(self):

for i in range(0, BATCH_SIZE, 2):
self.assertEqual(batch_output[i], batch_output[i + 1])


@require_torch
class SwitchTransformersSparseMLPTests(unittest.TestCase):
def test_token_dropping(self):
r"""
This test checks if the token dropping actually drops tokens.
"""
config = SwitchTransformersConfig(expert_capacity=0) # we drop everything
moe = SwitchTransformersSparseMLP(config)
dropped_token_results = moe(torch.randn(2, 3, 768))[0]

assert (dropped_token_results == 0).all(), f"Some tokens not dropped: {dropped_token_results}."