SwitchTransformer: Initialization of tensor to collect expert results is incorrect for dropped tokens (from ML POV) #37017
Comments
Hi @mario-aws this seems legit - but unfortunately (for you) I think you might be the most qualified person to understand it and fix the bug at this point! Would you be willing to make the fix in a PR, and maybe add an example/test to show the difference?
Thanks for the quick feedback. I will familiarize myself more with the HF code base and PR procedures and provide a PR. In the meantime, let me share the simple reproducer test that I have in mind. The fix is a one-liner.
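A minimal sketch of that test (assuming the sparse MLP can be instantiated directly from a config; import paths and the returned tuple follow the current modeling code):

```python
import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersSparseMLP,
)


def test_sparse_mlp_returns_zeros_when_all_tokens_are_dropped():
    # expert_capacity=0 is the extreme case: the router drops every token.
    config = SwitchTransformersConfig(expert_capacity=0)
    mlp = SwitchTransformersSparseMLP(config)
    mlp.eval()  # router jitter noise is only applied in training mode

    hidden_states = torch.randn(2, 8, config.d_model)
    with torch.no_grad():
        output, _ = mlp(hidden_states)

    # Dropped tokens receive no expert output, so the module (the residual
    # connection lives elsewhere) should return exactly zeros.
    assert torch.allclose(output, torch.zeros_like(output))
```

Before the fix, this fails because the output equals the router probabilities times the original hidden states.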
Previously, the identity function was used for dropped tokens with a weight from the expert that was not applied to the hidden states. This was misleading, because dropping means the expert weight is zero. Instead of trying to fix the weight, we take an easier approach by initializing with zeros. Fixes issue huggingface#37017
I created a respective PR: #37123
System Info
This is about a logical bug from an ML point of view. It will not result in crashes, but it influences model behavior significantly.
In the transformers code of SwitchTransformer, we initialize the tensor for collecting expert results for an MLP with the hidden states and then update it via index updates and an eventual scaling by the router probabilities.
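In code, the relevant part of SwitchTransformersSparseMLP.forward looks roughly like this (a simplified excerpt; the actual implementation only loops over experts that received at least one token):

```python
# next_states starts as a copy of the hidden states, so any token that no
# expert claims keeps its original values instead of becoming zero.
next_states = hidden_states.clone()
for idx, expert in enumerate(self.experts.values()):
    token_indices = router_mask[:, :, idx].bool()
    next_states[token_indices] = expert(hidden_states[token_indices])
# For dropped tokens, this scales the *original* hidden states by the
# probability of an expert that never processed them.
hidden_states = router_probs * next_states
```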
While this logic is fine for all tokens that are not dropped, it is wrong for dropped tokens. Setting expert_capacity to zero as an extreme test where all tokens get dropped, one gets as output the original hidden states scaled by the probability of the respective expert they were never assigned to. Note that router_probs is not set to zero for dropped tokens. Also note that the residual connection happens at a different stage and is not related to this part of the code. Why is this wrong from an ML POV? Dropping a token means no expert processes it, so the module's contribution for that token should be zero; instead, its hidden states leak through, scaled by the probability of an expert that never touched them.
This has probably not been seen so far, because usually only very few tokens are intended to be dropped. So, changing this behavior will probably not have much impact in the grand scheme of things. A fix could look like:
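One line in SwitchTransformersSparseMLP.forward, i.e. the zero initialization described in the commit message above (shown against the simplified excerpt):

```python
# Initialize with zeros so dropped tokens contribute nothing, instead of
# leaking their probability-scaled hidden states through the MoE layer.
next_states = torch.zeros_like(hidden_states)  # was: hidden_states.clone()
```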
transformers-cli env output (probably not relevant)
transformers version: 4.46.2
Who can help?
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
I do not have a nice example yet, but it should look something like:
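For instance (same setup as the test sketched in the comment above; names are illustrative):

```python
import torch

from transformers import SwitchTransformersConfig
from transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersSparseMLP,
)

config = SwitchTransformersConfig(expert_capacity=0)  # force every token to be dropped
mlp = SwitchTransformersSparseMLP(config).eval()
output, _ = mlp(torch.randn(2, 8, config.d_model))
```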
Expected behavior
The result of the module should be all zeros if all tokens are dropped, not some arbitrary scaling of the data.
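Using the output from the sketch above:

```python
assert torch.allclose(output, torch.zeros_like(output))
```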
That would probably be the respective assertion.