I have a codebase forked from torchtitan with minor changes. FSDP trains very well with minimal instability, but HSDP on the same codebase exhibits loss spikes.
Is there some reason for this that you folks can think of? Note that I have implemented gradient accumulation in my fork, though without changing any sharding behavior (it just accumulates gradients over a larger effective batch size).
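For context, the only difference between the two setups here is the device mesh: FSDP shards over a flat 1D mesh, while HSDP replicates across one mesh dimension and shards within the other. A minimal sketch of the distinction, assuming PyTorch's FSDP2 `fully_shard` API on 8 GPUs (the mesh names, degrees, and the toy model below are placeholders, not the exact torchtitan configuration):

```python
# Run under torchrun, e.g.: torchrun --nproc_per_node=8 mesh_sketch.py
# Assumes a recent PyTorch that exposes the FSDP2 fully_shard API.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

model = nn.Linear(1024, 1024)  # placeholder model

# FSDP: shard parameters over all 8 ranks (1D mesh).
fsdp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))
fully_shard(model, mesh=fsdp_mesh)

# HSDP: replicate across 2 groups, shard within each group of 4 (2D mesh).
# hsdp_mesh = init_device_mesh(
#     "cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
# )
# fully_shard(model, mesh=hsdp_mesh)
```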
If it's not an HSDP bug (is it?), here are some things I'd look at:
For gradient accumulation, are you doing a sum or an average over the micro-batch gradients? To match mean cross-entropy loss backward on the larger batch size, you'll need to average (see the sketch after this list).
You can compare gradient accumulation vs. no gradient accumulation at a small scale and check that the numerics match, e.g. batch size 10 without accumulation vs. batch size 1 accumulated over 10 micro-batches.
Check data loading behavior -- are you loading the same data each global batch with FSDP vs. HSDP? Although, even if the data loading behaviors differ, as long as the data is randomly distributed, I wouldn't expect the loss behavior to be very different.
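To make the first two points concrete, here is a minimal sketch of loss-scaled gradient accumulation (not your fork's actual code; `model`, `optimizer`, `data_loader`, and `accum_steps` are placeholders):

```python
import torch.nn.functional as F

accum_steps = 10  # micro-batches per optimizer step (placeholder value)

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(data_loader):
    logits = model(inputs)
    # Mean cross-entropy over the micro-batch; dividing by accum_steps makes
    # the accumulated gradient match mean CE over the full global batch.
    loss = F.cross_entropy(logits, targets) / accum_steps
    loss.backward()
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

With this scaling in place, the batch-size-10 run without accumulation and the batch-size-1 run with `accum_steps=10` should produce nearly identical gradients on the same data (up to floating-point reduction order), so any larger divergence would point at the accumulation or sharding logic.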
but HSDP on the same codebase exhibits loss spikes
Curious what the spikes look like? Maybe a plot of HSDP vs. FSDP would help. I know there is a spike after warmup, but I'd like to see whether there are spikes here and there all along the training.