simplify streaming diloco #233
Conversation
6069b95 to 3d53cdc
LGTM
-            pseudogradient = local_param - self.original_parameters[name].to(
-                p.device
+            pseudogradient = (
+                self.original_parameters[name].to(p.device) - local_param
This is flipped because we don't do 1- below?
What do you mean by `1 -`? The outer optimizer will do `param = param - pseudo_grad`, but the loss goes down in the direction of `-pseudo_grad`. This is pretty much why we were seeing the loss go up earlier, I think.
Proof of this:
- Say `pseudo_grad = new_param - old_param`, where `new_param = old_param - grad` after the local steps.
- Then `pseudo_grad = (old_param - grad) - old_param = -grad`.
- The outer optimizer step is `param = old_param - pseudo_grad = old_param + grad`.
- This is incorrect: it should just be `param = old_param - grad`.
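
To make the sign convention concrete, here is a minimal numeric sketch (plain Python with illustrative variable names, not the actual torchft code):

```python
# Toy quadratic loss L(p) = 0.5 * p**2, so grad = p and the minimum is at 0.
old_param = 2.0
inner_lr = 0.1

# Local/inner steps move the parameter downhill.
local_param = old_param - inner_lr * old_param   # 1.8

# Correct convention: pseudo_grad = old - new points uphill, like a real gradient.
pseudo_grad = old_param - local_param            # +0.2
good = old_param - pseudo_grad                   # 1.8, loss went down

# Flipped convention: pseudo_grad = new - old undoes the local progress.
flipped = local_param - old_param                # -0.2
bad = old_param - flipped                        # 2.2, loss went up
assert abs(good) < abs(old_param) < abs(bad)
```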
@@ -588,7 +553,11 @@ def __init__(
        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be synchronized at a time")

        if fragment_sync_delay >= sync_every:
        if sync_every % len(model_fragments) != 0:
            raise ValueError("sync_every must be a multiple of the number of fragments")
This makes sense for now -- we can relax this later if it turns out people want to sync different parts of the model at different rates, though that has other significant considerations.
We can still sync at different rates by passing in a different `sync_every`. We'll need to make it configurable in torchtitan too.
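
For intuition, here is a hedged sketch of how the fragment to sync could be derived from the manager's step alone once `sync_every` is a multiple of the number of fragments. `fragment_to_sync` and its arguments are illustrative names, not the actual torchft API:

```python
from typing import Optional

def fragment_to_sync(step: int, sync_every: int, num_fragments: int) -> Optional[int]:
    """Return the fragment index to sync at this manager step, or None if no sync is due.

    Assumes num_fragments divides sync_every (the check added in this PR), so syncs
    are spread evenly across steps and each fragment syncs once every sync_every steps.
    """
    steps_per_fragment = sync_every // num_fragments
    if step % steps_per_fragment != 0:
        return None
    return (step // steps_per_fragment) % num_fragments

# Example: 3 fragments, sync_every = 6 -> some fragment syncs every 2 steps,
# and each individual fragment syncs every 6 steps.
print([fragment_to_sync(s, sync_every=6, num_fragments=3) for s in range(1, 13)])
# [None, 1, None, 2, None, 0, None, 1, None, 2, None, 0]
```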
606428d to 5d57f7f
Summary:
- move the training loop to a separate file
- convert it into a class so that methods can be overridden without having to duplicate code
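
As a rough illustration of the refactor described above (class and method names are hypothetical, not the actual file in this PR):

```python
class TrainLoop:
    """Training loop as a class so individual steps can be overridden by subclasses."""

    def __init__(self, model, optimizer, data_loader):
        self.model = model
        self.optimizer = optimizer
        self.data_loader = data_loader

    def forward_backward(self, batch):
        # Subclasses can override this to change how the loss is computed.
        loss = self.model(batch).mean()
        loss.backward()
        return loss

    def train_step(self, batch):
        self.optimizer.zero_grad()
        loss = self.forward_backward(batch)
        self.optimizer.step()
        return loss

    def run(self, num_steps):
        data = iter(self.data_loader)
        for _ in range(num_steps):
            self.train_step(next(data))
```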
Summary:
- since we made the simplifying assumption that we will only ever have 1 in-flight fragment, we can simplify some of the logic, in particular getting rid of the local step in manager state
  - we'll just use the manager's step to determine which fragment to sync
  - this also lets us easily support heterogeneous hardware by tuning the sync_every setting so that slower/faster machines perform fewer/more local steps before they sync
- we can also perform quorum right before preparing a fragment sync
  - this easily ensures that all replicas will have the same max step and sync the same fragment
- fix some numeric issues
  - the sign of the pseudogradient
  - in-place lerp when mixing local and global parameters (a short sketch follows below)
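
For the in-place lerp fix mentioned above, a short sketch (the tensors and mixing fraction here are illustrative; this assumes the mix is between the local parameter and the globally averaged one):

```python
import torch

local_param = torch.randn(4)
global_param = torch.randn(4)
fraction = 0.5  # how much of the global average to mix in

# In-place: local_param <- local_param + fraction * (global_param - local_param).
# The in-place variant avoids allocating a new tensor and keeps the parameter's
# storage intact, so optimizer state and views keep pointing at the same memory.
local_param.lerp_(global_param, fraction)
```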
Summary:
Stack created with Sapling. Best reviewed with ReviewStack.