
simplify streaming diloco #233


Merged: tushar00jain merged 2 commits into pytorch:main from the pr233 branch on Jul 15, 2025

Conversation

@tushar00jain (Contributor) commented on Jul 14, 2025

Summary:

  • since we made the simplifying assumption that we will only ever have 1 inflight fragment, we can simplify some of the logic, in particular getting rid of the local step in manager state
  • we'll just use the manager's step to determine which fragment to sync (see the sketch after this list)
  • this also allows us to easily support heterogeneous hardware by tuning the sync_every setting, so that slower/faster machines perform fewer/more local steps before they sync
  • we can also perform quorum right before preparing a fragment sync; this ensures that all replicas will have the same max step and sync the same fragment
  • fix some numeric issues:
    • the sign of the pseudogradient
    • in-place lerp when mixing local and global parameters
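To make the fragment-selection bullet concrete, here is a minimal sketch of how the fragment to sync could be derived from the manager's step alone; the names (`fragment_to_sync`, `manager_step`, `num_fragments`) are illustrative assumptions, not the actual torchft Manager API.

```python
from typing import Optional

def fragment_to_sync(manager_step: int, num_fragments: int, sync_every: int) -> Optional[int]:
    """Illustrative sketch: pick the fragment to sync purely from the
    manager's step, assuming a single inflight fragment and that
    sync_every is a whole multiple of the fragment count."""
    steps_per_fragment = sync_every // num_fragments
    if manager_step == 0 or manager_step % steps_per_fragment != 0:
        return None  # keep taking local steps, nothing to sync yet
    # Fragments rotate round-robin as the manager's step advances, so all
    # replicas that agree on the step also agree on the fragment.
    return (manager_step // steps_per_fragment) % num_fragments
```

Under this sketch, with num_fragments=2 and sync_every=4, steps 2, 4, 6, 8 would sync fragments 1, 0, 1, 0; tuning sync_every per machine then only changes how many local steps happen between those sync points.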

Stack created with Sapling. Best reviewed with ReviewStack.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 14, 2025
@tushar00jain tushar00jain requested review from d4l3k and H-Huang and removed request for d4l3k July 14, 2025 17:49
@tushar00jain force-pushed the pr233 branch 3 times, most recently from 6069b95 to 3d53cdc on July 14, 2025 18:42
@d4l3k (Member) left a comment:

LGTM

```diff
- pseudogradient = local_param - self.original_parameters[name].to(
-     p.device
- )
+ pseudogradient = (
+     self.original_parameters[name].to(p.device) - local_param
+ )
```
@d4l3k (Member) commented on the change above:

This is flipped because we don't do 1- below?

@tushar00jain (Contributor, Author) replied:

what do you mean by 1-? the outer optimizer will do param = param - pseudo_grad, but the loss goes down in the direction of -pseudo_grad. this is pretty much why we were seeing the loss going up earlier, i think.

@tushar00jain (Contributor, Author) added:

proof of this:

  • let's say pseudo_grad = new_param - old_param, where new_param = old_param - grad, so pseudo_grad = -grad
  • the outer optimizer step is then new_param = old_param - pseudo_grad = old_param - (-grad) = old_param + grad
  • this is incorrect because it should just be new_param = old_param - grad
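A quick numeric check of the argument above, using plain PyTorch rather than the torchft code:

```python
import torch

old_param = torch.tensor([1.0])
grad = torch.tensor([0.5])
new_param = old_param - grad  # one inner step with lr = 1

# Wrong sign: pseudo_grad = new_param - old_param = -grad, so the outer
# step old_param - pseudo_grad = old_param + grad climbs the loss.
wrong = old_param - (new_param - old_param)
assert torch.equal(wrong, old_param + grad)

# Correct sign: pseudo_grad = old_param - new_param = grad, so the outer
# step recovers the descent direction old_param - grad.
correct = old_param - (old_param - new_param)
assert torch.equal(correct, old_param - grad)
```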

```diff
@@ -588,7 +553,11 @@ def __init__(
         if sync_every < len(model_fragments):
             raise ValueError("Only 1 fragment can be syncrhonized at a time")

+        if sync_every % len(model_fragments) != 0:
+            raise ValueError("sync_every must divide the number of fragments")
+
         if fragment_sync_delay >= sync_every:
```
@d4l3k (Member) commented:

This makes sense for now -- we can relax this later if it turns out people want to sync different parts of the model at different rates, though that has other significant considerations

@tushar00jain (Contributor, Author) replied:

we can still sync at different rates by passing in a different sync_every. we'll need to make it configurable in torchtitan too.

@tushar00jain force-pushed the pr233 branch 2 times, most recently from 606428d to 5d57f7f on July 15, 2025 17:29
Commit 1 summary:
- move the training loop to a separate file
- convert it into a class so that methods can be overridden without having to duplicate code

Commit 2 summary:
- since we made the simplifying assumption that we will only ever have 1 inflight fragment, we can simplify some of the logic, in particular getting rid of the local step in manager state
- we'll just use the manager's step to determine which fragment to sync
- this also allows us to easily support heterogeneous hardware by tuning the sync_every setting, so that slower/faster machines perform fewer/more local steps before they sync
- we can also perform quorum right before preparing a fragment sync; this ensures that all replicas will have the same max step and sync the same fragment
- fix some numeric issues
  - the sign of the pseudogradient
  - in-place lerp when mixing local and global parameters (see the sketch below)
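A minimal sketch of the in-place lerp fix referenced above (plain PyTorch; the tensors and `alpha` are illustrative, not the exact torchft code):

```python
import torch

local_param = torch.tensor([1.0, 2.0])
global_param = torch.tensor([3.0, 4.0])
alpha = 0.5  # mixing weight toward the global parameters

# In-place: local_param <- local_param + alpha * (global_param - local_param).
# Mutating the tensor in place keeps optimizer state and any external
# references pointing at the same storage, unlike the out-of-place
# `local_param = torch.lerp(local_param, global_param, alpha)`.
local_param.lerp_(global_param, alpha)
assert torch.equal(local_param, torch.tensor([2.0, 3.0]))
```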
@tushar00jain merged commit 8170a4b into pytorch:main on Jul 15, 2025
14 of 15 checks passed
@tushar00jain deleted the pr233 branch on July 15, 2025 19:47