Correctly drop tokens in SwitchTransformer #37123

Merged
merged 1 commit into huggingface:main from switch_token_skipping_fix on Apr 10, 2025

Conversation

mario-aws
Contributor

@mario-aws mario-aws commented Mar 31, 2025

What does this PR do?

Previously, dropped tokens were passed through unchanged (the identity function) and then scaled by an expert weight, even though no expert had actually been applied to their hidden states. This was misleading: dropping a token means its expert contribution should be zero. Instead of trying to fix the weight, we take the simpler approach of initializing the collecting tensor with zeros.

Fixes #37017
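
To make the behavior concrete, here is a minimal, self-contained sketch that only mimics the expert loop of SwitchTransformersSparseMLP (the shapes, toy experts, and variable names are illustrative assumptions, not the actual diff):

import torch

# Toy shapes (hypothetical). router_mask is one-hot over experts for accepted
# tokens and all zeros for tokens dropped because the expert is over capacity;
# with expert_capacity=0 every token is dropped, so the mask stays all zeros.
batch, seq, d_model, num_experts = 2, 4, 8, 2
hidden_states = torch.randn(batch, seq, d_model)
router_probs = torch.rand(batch, seq, 1)
router_mask = torch.zeros(batch, seq, num_experts)
experts = [torch.nn.Linear(d_model, d_model) for _ in range(num_experts)]

# Old behavior: the collecting tensor starts as a copy of the input, so dropped
# tokens pass through unchanged and are then scaled by router_probs.
old_next_states = hidden_states.clone()
# New behavior: the collecting tensor starts at zero, so dropped tokens
# contribute nothing to the layer output.
new_next_states = torch.zeros_like(hidden_states)

for idx, expert in enumerate(experts):
    token_indices = router_mask[:, :, idx].bool()  # empty selection here
    old_next_states[token_indices] = expert(hidden_states[token_indices])
    new_next_states[token_indices] = expert(hidden_states[token_indices])

print((router_probs * old_next_states).abs().sum())  # nonzero: dropped tokens leak through
print((router_probs * new_next_states).abs().sum())  # tensor(0.): dropped tokens are really dropped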

Related work

https://github.com/tensorflow/mesh/blob/e6798a2610a2c2f4c4cd236d8214422cb1ecc00a/mesh_tensorflow/transformer/moe.py#L1144 notes that the output for dropped tokens needs to be zeroed out.

https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py#L507C18-L507C31 combines the results without any clone initialization beforehand.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

text models: @ArthurZucker
last major changing person: @zucchini-nlp
person who requested PR: @Rocketknight1

@github-actions github-actions bot marked this pull request as draft March 31, 2025 03:54
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@marthos1

marthos1 commented Apr 2, 2025

Would love to be part of this project one day.
I'm not a coder, but I do have a vision... :)

@mario-aws
Contributor Author

@Rocketknight1

@Rocketknight1
Member

LGTM but I'm not an expert on MoE routing! @zucchini-nlp @ArthurZucker if you're happy with it feel free to merge

@ydshieh ydshieh removed their request for review April 4, 2025 15:38
Member

@zucchini-nlp zucchini-nlp left a comment

Looks okay to me, as long as the slow tests are green

r"""
This test checks if the token dropping actually drops tokens.
"""
config = SwitchTransformersConfig(expert_capacity=0) # we drop everything
Member

Using SwitchTransformersConfig with defaults initializes a huge model; let's make it tiny. We could even move this under the general model tests.
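
A tiny config along these lines could look like the following sketch (the exact values are illustrative; the defaults match the sizes mentioned in the reply below):

from transformers import SwitchTransformersConfig

config = SwitchTransformersConfig(
    d_model=32,         # default is 768
    d_ff=64,            # default is 2048
    num_experts=2,      # default is 8
    expert_capacity=0,  # we still drop everything
)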

Contributor Author

Note that this config is used to init a single expert-MLP module with 8 experts, hidden size 768, and intermediate size 2048, not a full model. There is no attention and there are no multiple layers. In my case, this test ran super fast on CPU.

I could adjust the shapes to the ones chosen in SwitchTransformersModelTester, but I don't think it would have much impact.

I am new to HF testing and this part of the code. If I moved this test to the general model tests, I assume I would have to initialize and run a whole model. In that case, I could not easily assert on the result of the module: embeddings, attention, residual connections, and the head would all influence the final output. The advantage of testing SwitchTransformersSparseMLP directly is that I know the result must be all zeros, independent of the input (and of the input/model shapes).
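
For context, a module-level test along the lines described above might look like this sketch (the function name and tensor shapes are hypothetical; the merged test may differ in detail):

import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersSparseMLP,
)

def test_token_dropping_returns_zeros():
    r"""
    This test checks if the token dropping actually drops tokens.
    """
    config = SwitchTransformersConfig(expert_capacity=0)  # we drop everything
    moe = SwitchTransformersSparseMLP(config)
    hidden_states = torch.rand(2, 3, config.d_model)
    output = moe(hidden_states)[0]  # first element is the MoE layer output
    # With zero capacity no expert processes any token, so the output must be
    # exactly zero regardless of the input values.
    assert torch.allclose(output, torch.zeros_like(output))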

@zucchini-nlp
Member

run-slow: switch_transformers

Contributor

github-actions bot commented Apr 7, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/switch_transformers']
quantizations: [] ...

@mario-aws mario-aws force-pushed the switch_token_skipping_fix branch from f30dc00 to 8a79849 on April 8, 2025 14:22
@ydshieh
Collaborator

ydshieh commented Apr 10, 2025

run-slow: switch_transformers

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/switch_transformers']
quantizations: [] ...

@ydshieh
Collaborator

ydshieh commented Apr 10, 2025

pytest/custom-tests — Slow CI job

It's green, so might be ready to go, @zucchini-nlp ?

@zucchini-nlp
Member

Cool thanks! Merging

@zucchini-nlp zucchini-nlp merged commit bde41d6 into huggingface:main Apr 10, 2025
13 checks passed
cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025

Successfully merging this pull request may close these issues.

SwitchTransformer: Initialization of tensor to collect expert results is incorrect for dropped tokens (from ML POV)
6 participants