add MHSA module #7
Conversation
Why do we need this at all?
Again, as in the other PRs (#4, #6), I would keep this consistent with our existing implementation (see rwth-i6/returnn_common#58 for the initial code, but then also check the current code, also in the RETURNN frontend, and also check with @mmz33). We discussed many things about the naming of modules and also about their behavior.
In this specific case, you would use a standard generic RelPosSelfAttention / RelPositionMultiHeadedAttention, or the standard PyTorch MultiheadAttention.
Also, this module here would not add the layer-norm, the residual connection, or dropout. That would all be handled by the outer module (ConformerEncoderLayer).
So, rather than reimplementing the already existing MultiheadAttention, what is missing is something like RelPositionMultiHeadedAttention. We have that in RETURNN-common / RETURNN-frontend already, so I would suggest to just follow that implementation, unless there are good reasons to do something differently. But I would keep at least the naming, arguments, and variable names consistent.
Some tests should also be added, similar to #4.
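For concreteness, a minimal sketch of the split suggested here could look as follows. The class and argument names (ConformerMHSA, ConformerEncoderLayer, att_dropout, ...) are illustrative assumptions, not the actual RETURNN-common / RETURNN-frontend API, and the relative-positional variant is omitted:

```python
import torch
from torch import nn


class ConformerMHSA(nn.Module):
    """Hypothetical sketch of the attention sub-module only: no layer-norm,
    no residual connection, no output dropout. Names and arguments are
    illustrative, not the actual project API."""

    def __init__(self, embed_dim: int, num_heads: int, att_dropout: float = 0.1):
        super().__init__()
        # Standard PyTorch attention; a RelPositionMultiHeadedAttention would
        # take this place if relative positional encoding is wanted.
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=att_dropout, batch_first=True
        )

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # x: [batch, time, embed_dim]; key_padding_mask: [batch, time], True = padded.
        out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)
        return out


class ConformerEncoderLayer(nn.Module):
    """Sketch of the outer module that owns layer-norm, dropout and the residual."""

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.mhsa_layer_norm = nn.LayerNorm(embed_dim)
        self.mhsa = ConformerMHSA(embed_dim, num_heads)
        self.dropout = nn.Dropout(dropout)
        # Feed-forward and convolution blocks of the Conformer layer are omitted here.

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # Pre-norm -> attention -> dropout -> residual, all handled outside the MHSA module.
        residual = x
        x = self.mhsa_layer_norm(x)
        x = self.mhsa(x, key_padding_mask=key_padding_mask)
        x = self.dropout(x)
        return residual + x
```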
Please see comments. Some minor things, but also some more fundamental ones.
Regarding the failing test: I think you need to add torch to the requirements.txt, since it's not done in main yet.
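For reference, a minimal requirements.txt entry could look like this (whether and how to pin a version is an assumption, not something decided in this thread):

```text
# requirements.txt
torch
```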
Only one comment left from my side
…umentation on key_padding_mask
pretty much the same as in torchaudio.models.conformer.
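As an illustration of the key_padding_mask convention used by torch.nn.MultiheadAttention (and, by assumption, the same convention a torchaudio-style Conformer follows): it is a boolean [batch, time] tensor where True marks padded positions to be ignored, typically derived from per-sequence lengths:

```python
import torch

# Hypothetical example: build key_padding_mask from per-sequence lengths.
# Shape [batch, time]; True marks padded frames that attention should ignore.
lengths = torch.tensor([5, 3, 4])
max_time = int(lengths.max())
key_padding_mask = torch.arange(max_time)[None, :] >= lengths[:, None]
# key_padding_mask[b, t] is True for t >= lengths[b]
```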