Skip to content

Refactored auto-microbatching hook handles for FSDP #3843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

rithwik-db
Copy link
Contributor

@rithwik-db rithwik-db commented May 2, 2025

Refactored auto-microbatching hook handles for FSDP1 with additional documentation.

This PR was originally designed to support FSDP2 auto microbatching, but since there are additional issues with FSDP2 state there, we moved that to a draft PR: #3866

@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to address automicrobatching [WIP] Added hook handles for FSDP2 to address automicrobatching May 2, 2025
@rithwik-db rithwik-db changed the title [WIP] Added hook handles for FSDP2 to address automicrobatching Added hook handles for FSDP2 to supported auto microbatching May 2, 2025
@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to supported auto microbatching Added hook handles for FSDP2 to support auto microbatching May 2, 2025
@rithwik-db rithwik-db force-pushed the hookhandles branch 2 times, most recently from b126818 to 2a1cfb4 Compare May 5, 2025 22:11
@rithwik-db rithwik-db force-pushed the hookhandles branch 2 times, most recently from fa82a6b to 647ad56 Compare May 21, 2025 00:32
@rithwik-db rithwik-db requested a review from bowenyang008 May 21, 2025 20:27
@bowenyang008
Copy link
Contributor

can we run an e2e test to verify it works?

@rithwik-db
Copy link
Contributor Author

can we run an e2e test to verify it works?

It seems hard to curate an e2e test to catch this failure that would be more informative than our current unit tests. We have these two unit tests:

  1. FSDP1
  2. FSDP2

I guess in theory, we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch given a specific batch size or we could use a MPT module with a massive hidden layer in the FFN that will run into OOM for one batch size but not 1/2 of it...

@rithwik-db rithwik-db requested a review from bowenyang008 May 22, 2025 06:50
@bowenyang008
Copy link
Contributor

can we run an e2e test to verify it works?

It seems hard to curate an e2e test to catch this failure that would be more informative than our current unit tests. We have these two unit tests:

  1. FSDP1
  2. FSDP2

I guess in theory, we could create a larger example where a certain MPT module raises a CUDA OOM error at a certain epoch given a specific batch size or we could use a MPT module with a massive hidden layer in the FFN that will run into OOM for one batch size but not 1/2 of it...

I meant just a general e2e test, does not have to trigger OOM

@rithwik-db
Copy link
Contributor Author

rithwik-db commented May 22, 2025

Tested here: mpt-7b-fsdp2-p39FPR and compared it to base (mpt-7b-fsdp2-AKLNwv) and the numbers look good based on the tolerations mentioned in the regression testing PR @bowenyang008 (note that it defaults to 8 microbatch size when auto is set)

@rithwik-db rithwik-db changed the title Added hook handles for FSDP2 to support auto microbatching Refactored auto-microbatching hook handles for FSDP May 29, 2025
fixed test issues

formatted

gated non-wrapped to FSDP1

updated for FSDP2

propagated changes to trainer

added minor test fix

formatted

formatted once more

addressed comments

formatted

minor fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants