Refactored auto-microbatching hook handles for FSDP #3843
Can we run an e2e test to verify it works?
It seems hard to write an e2e test that catches this failure and is more informative than our current unit tests. We have these two unit tests: I guess in theory we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch for a specific batch size, or we could use an MPT module with a massive hidden layer in the FFN that runs into OOM at one batch size but not at half of it...
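To make the "OOM at one batch size but not at half of it" case concrete, here is a minimal sketch of the halve-on-OOM retry loop that such a test would exercise. It is pure Python with a simulated error instead of a real CUDA OOM; `SimulatedOOM`, `run_microbatch`, `find_microbatch_size`, and `memory_limit` are hypothetical names for illustration, not Composer's API. It also ignores the cross-rank coordination that FSDP needs, since every rank has to take part in the same collectives.

```python
class SimulatedOOM(RuntimeError):
    """Stand-in for a CUDA out-of-memory error (hypothetical, illustration only)."""


def run_microbatch(size: int, memory_limit: int) -> None:
    # Hypothetical training step: any microbatch larger than
    # `memory_limit` samples is treated as exhausting device memory.
    if size > memory_limit:
        raise SimulatedOOM(f"microbatch of {size} samples does not fit")


def find_microbatch_size(batch_size: int, memory_limit: int) -> int:
    """Halve the microbatch size until a single step succeeds."""
    size = batch_size
    while True:
        try:
            run_microbatch(size, memory_limit)
            return size
        except SimulatedOOM:
            if size == 1:
                # Nothing smaller to try, so give up.
                raise RuntimeError("even a single sample does not fit")
            size = max(size // 2, 1)
```

With a batch of 64 and a limit of 10 samples, the loop tries 64, 32, and 16, then succeeds at 8. An e2e test along these lines would need a model whose memory footprint sits between two consecutive sizes.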
I meant just a general e2e test; it does not have to trigger OOM.
Tested here: mpt-7b-fsdp2-p39FPR, compared against base (mpt-7b-fsdp2-AKLNwv). The numbers look good based on the tolerances mentioned in the regression testing PR @bowenyang008 (note that it defaults to a microbatch size of 8 when auto is set).
fixed test issues
formatted
gated non-wrapped to FSDP1
updated for FSDP2
propagated changes to trainer
added minor test fix
formatted
formatted once more
addressed comments
formatted
minor fix
Refactored auto-microbatching hook handles for FSDP1 with additional documentation.
This PR was originally designed to support FSDP2 auto microbatching, but because of additional issues with FSDP2 state, that work was moved to a draft PR: #3866