HSDP causes loss instability #813

Open
apkumar opened this issue Jan 31, 2025 · 3 comments
Labels: module: fsdp, question (Further information is requested)

apkumar commented Jan 31, 2025

I have a codebase forked from torchtitan with minor changes. FSDP trains very well with minimal instability, but HSDP on the same codebase exhibits loss spikes.

Is there any reason for this you folks can think of? Note that I have implemented gradient accumulation in my fork, though without changing any sharding behavior (it just accumulates gradients over a larger batch size).

tianyu-l added the question (Further information is requested) label on Jan 31, 2025
@tianyu-l (Contributor)

If it's not an HSDP bug (is it?), here are some things I'd look at:

  1. For gradient accumulation, are you doing sum or average on the gradients? To make it equivalent to a mean cross-entropy loss backward on the larger batch size, you'll need to average (see the sketch after this list).
  2. You can compare gradient accumulation vs. no gradient accumulation at a small scale and check whether the numerics match, e.g. batch size 10 without accumulation vs. batch size 1 with accumulation over 10 micro-batches.
  3. Check the data loading behavior -- are you loading the same data in each global batch with FSDP vs. HSDP? That said, even if the data loading behaviors differ, as long as the data is randomly distributed I wouldn't expect the loss curves to be very different.
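
To illustrate points 1 and 2, here is a minimal sketch with a toy model and random data (not torchtitan's actual training loop): scaling each micro-batch loss by 1/accum_steps makes the accumulated gradients match a single backward over the full batch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative toy setup -- not torchtitan's actual model or data.
model = torch.nn.Linear(16, 4)
inputs = torch.randn(10, 16)              # a "global batch" of 10 samples
targets = torch.randint(0, 4, (10,))
accum_steps = 10

# (a) No grad accumulation: one backward over the full batch (mean CE loss).
model.zero_grad()
F.cross_entropy(model(inputs), targets).backward()
ref_grad = model.weight.grad.clone()

# (b) Grad accumulation: 10 micro-batches of size 1, each loss scaled by
#     1/accum_steps so the summed gradients match the mean-loss gradient in (a).
model.zero_grad()
for i in range(accum_steps):
    loss = F.cross_entropy(model(inputs[i:i+1]), targets[i:i+1]) / accum_steps
    loss.backward()  # gradients sum into .grad across micro-batches

print(torch.allclose(ref_grad, model.weight.grad, atol=1e-6))  # expect True
```

If the two gradients don't match in a check like this, the gradient accumulation (rather than HSDP itself) may be the source of the instability.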

gnadathur assigned gnadathur, mori360, and weifengpy and unassigned gnadathur on Feb 6, 2025
@gnadathur (Contributor)

cc: @weifengpy

@weifengpy (Contributor)

> but HSDP on the same codebase exhibits loss spikes

Curious what the spikes look like? A plot of HSDP vs. FSDP would help. I know there is a spike after warm-up, but I'd like to see whether there are spikes here and there throughout training.
