simplify streaming diloco #233

Merged
merged 2 commits on Jul 15, 2025
203 changes: 41 additions & 162 deletions torchft/local_sgd.py
@@ -281,8 +281,8 @@ def _save_grads(self) -> None:
local_param = p.to_local()
else:
local_param = p
pseudogradient = local_param - self.original_parameters[name].to(
p.device
pseudogradient = (
self.original_parameters[name].to(p.device) - local_param
Member

This is flipped because we don't do 1- below?

Contributor Author

What do you mean by 1-? The outer optimizer will do param = param - pseudo_grad, but loss goes down in the direction of -pseudo_grad. This is pretty much why we were seeing the loss going up earlier, I think.

Contributor Author

proof of this (a small numeric sketch follows this hunk):

  • say pseudo_grad = new_param - old_param, where new_param = old_param - grad
  • the outer optimizer step is new_param = old_param - pseudo_grad = old_param - (old_param - grad) = old_param + grad
  • this is incorrect because it should just be new_param = old_param - grad

)
self._grads[name] = pseudogradient
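A standalone numeric sketch of the sign convention (illustration only, not part of the diff): with pseudogradient = original - local, an SGD-style outer step of the form param = param - lr * pseudo_grad (as described in the comment above) moves the global parameters onto the locally trained ones.

import torch

original = torch.tensor([1.0, 2.0])    # parameters saved before the inner steps
local = torch.tensor([0.9, 1.8])       # parameters after local training

pseudograd = original - local          # sign convention used in this PR
updated = original - 1.0 * pseudograd  # outer optimizer: param = param - lr * pseudo_grad

assert torch.allclose(updated, local)  # the outer step recovers the local update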

@@ -318,7 +318,7 @@ def _merge_parameters(self) -> None:
Merges the local and global parameters.
"""
for name, p in self._model_fragment.named_parameters():
p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha)
p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)

@torch.profiler.record_function("torchft::local_sgd::wait")
def wait(self) -> None:
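A side note on the change above (illustration only, not part of the diff): Tensor.lerp is out-of-place, so the old call discarded its result and the merge was effectively a no-op, while lerp_ mutates p.data in place and weights the other tensor by the given alpha, i.e. p = p + alpha * (local - p).

import torch

p = torch.tensor([0.0, 0.0])      # stands in for p.data
local = torch.tensor([1.0, 1.0])  # stands in for self._local_parameters[name]
alpha = 0.25                      # stands in for self._fragment_update_alpha

p.lerp_(local, alpha)             # in place: p <- p + alpha * (local - p)
assert torch.allclose(p, torch.tensor([0.25, 0.25]))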
@@ -335,20 +335,6 @@ def wait(self) -> None:

self._allreduce_futures = []

def should_prepare_fragment(self, step: int) -> bool:
"""
Determines if the fragment should be asynchronously sent to other replicas
"""
step_to_prepare = step - self._fragment_sync_offset
return step_to_prepare % self._sync_every == 0

def should_sync_fragment(self, step: int) -> bool:
"""
Determines if the fragment should be synchronized with other replicas
"""
step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
return step_to_sync % self._sync_every == 0

@torch.profiler.record_function("torchft::local_sgd::prepare_sync")
def prepare_sync(self) -> None:
"""
@@ -384,27 +370,6 @@ def perform_sync(self) -> bool:
steps using the outer optimizer.
"""
# Waiting for an allreduce before it has been sent is currently not supported.
# Please make sure to not do this to avoid running into inconsistencies.
#
# This can happen when using large values of `fragment_sync_delay`.
# The node might not have participated in syncing of this fragment.
#
# The allreduce for other nodes who did might actually
# succeed and in that case, we shouldn't allow recovery
# from this node.
#
# We do need to increase the `max_step` here so we
# don't end up in an infinite loop of needing to recover
# but we can't let other nodes recover from this node
# because it doesn't have the latest state.
#
# We can add a `is_catching_up` flag to the state_dict
# to disallow recoveries from this node. Such nodes can
# be excluded from `max_step` calculation unless all
# nodes are catching up. This approach makes the replica state
# of global parameters diverge though. So we could add recovery
# for a particular fragment from a peer node as a part of the
# `should_commit` or next `quorum` when a node is catching up.
assert len(self._allreduce_futures) > 0

self.wait()
@@ -588,7 +553,11 @@ def __init__(
if sync_every < len(model_fragments):
raise ValueError("Only 1 fragment can be synchronized at a time")

if fragment_sync_delay >= sync_every:
if sync_every % len(model_fragments) != 0:
raise ValueError("sync_every must be divisible by the number of fragments")
Member

This makes sense for now -- we can relax this later if it turns out people want to sync different parts of the model at different rates, though that has other significant considerations.

Contributor Author

We can still sync at different rates by passing in a different sync_every. We'll need to make it configurable in torchtitan too.


self._sync_every: int = sync_every // len(model_fragments)
if fragment_sync_delay >= self._sync_every:
raise ValueError(
"Fragment must be synced before it is reduced another time"
)
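A quick illustration of the arithmetic above (hypothetical values, not part of the diff): the user-facing sync_every is split evenly across fragments, which is why it must be divisible by the fragment count and why fragment_sync_delay is compared against the per-fragment value.

sync_every = 6
num_fragments = 3

assert sync_every % num_fragments == 0                  # otherwise __init__ raises ValueError
per_fragment_sync_every = sync_every // num_fragments   # each fragment is reduced every 2 steps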
@@ -599,23 +568,12 @@ def __init__(
super().__init__()
self._manager = manager

# Protects `_local_step`
self._lock = threading.Lock()

# The number of training iterations performed.
# Used to synchronize which fragment to send across all
# replicas
self._local_step = 0

# Sync `_local_step` with other replicas
self._manager.register_state_dict_fn(
"local_step",
self._load_step,
lambda: self._local_step,
)

# Used to perform quorum before any training happens
self._should_recover = True
self._fragment_sync_delay = fragment_sync_delay

self._hooks: List[RemovableHandle] = []

@@ -648,16 +606,9 @@ def __init__(
# `_StreamingDiLoCoFragment` about the fragment sync schedule.
assert fragment_sync_delay < sync_every // len(model_fragments)

# Used to ensure that we try to sync a fragment after we've sent a prepare for it
self._first_prepare_sent: set[int] = set()

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()

def _load_step(self, step: int) -> None:
with self._lock:
self._local_step = step

def _save_parameters(self) -> None:
for fragment in self._fragments:
fragment.save_parameters()
@@ -694,32 +645,19 @@ def _wait(self) -> None:
for fragment in self._fragments:
fragment.wait()

self._first_prepare_sent.clear()

def _quorum_loop(self) -> None:
def _current_fragment(self) -> int:
"""
Performs infinite retries until quorum is successful
Determines which fragment to prepare/sync based on the current step.
"""
while True:
self._manager.start_quorum()

if self._manager.errored() is None:
return
step = self._manager.current_step()
return step % len(self._fragments)

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
if self._should_recover:
# Get the correct step. This will continue after others have committed.
self._quorum_loop()
self._should_recover = False
# This is to be consistent with the nodes that are not recovering. They
# proceed with the code below on the step after quorum completes.
return

# We need to make sure all nodes send the same fragments in order.
# This is to avoid deadlocking e.g.
#
@@ -730,91 +668,32 @@ def _step_post_hook(
#
# Both of them will fail because Node A didn't send fragment 2
# and Node B didn't send fragment 1.
with self._lock:
self._local_step += 1
step = self._local_step

# Start sending fragments
for i, fragment in enumerate(self._fragments):
if not fragment.should_prepare_fragment(step):
continue

logger.debug(f"preparing fragment {i} at step {step}")

self._first_prepare_sent.add(i)
fragment.prepare_sync()

for i, fragment in enumerate(self._fragments):
if not fragment.should_sync_fragment(step):
continue

# We need to have sent an allreduce before we can sync
# a fragment
if i not in self._first_prepare_sent:
continue

logger.debug(f"syncing fragment {i} at step {step}")

if not fragment.perform_sync():
# Cancel all the previously scheduled allreduce by simply
# waiting for them. They should have failed but lets be
# paranoid anyway.
#
# We could choose to resend the failed fragments but that is
# more complicated since it involves coordinating all nodes to
# rewind and resend the fragments.
self._wait()

# Reset the local step. This is needed in case manager `should_commit` fails.
#
# This is because there can be a node that has the same `max_step` as the
# nodes that reached the commit point. However, this node failed before
# it could reach the commit point. So the local steps for these two nodes
# are not the same. But either can be used for recovery.
#
# To make sure both return the same step, we just reset the step to 0
# and start from scratch.
#
# In the happy path, we don't need to reset the step because --
# Nodes participating in the commit bumped their `max_step`.
# Any new nodes will take `local_step` from one of these nodes, which must
# be the same across all nodes because they took the same number of steps
# since the last commit to get to the most recent commit.
with self._lock:
self._local_step = 0

# Avoid doing allreduce after quorum failed.
#
# Maybe a different quorum formed without this node, so this node
# will incorrectly try to allreduce potentially on an incorrect
# fragment because the local_step is also out of sync.
# The replica will need recovery later anyway.
#
# So in case it didn't crash (e.g. network errors), we can save some
# training data by looping here. Otherwise that training data goes to
# waste after recovery
self._quorum_loop()

# TODO: Since we do quorum after commit, there might be a big gap until
# the next allreduce. This increases the chances of nodes failing
# and so the allreduce to fail.
# - We could maybe do a quorum again right before preparing for a fragment
# using `shrink_only`. This might make it tricky for new nodes to join
# though.
# - Maintain a sequence number in the state dict that gets bumped at every
# quorum call. Then we can do a quorum right before allreduce and avoid
# doing quorums after commit.

# We need to make sure `_local_step` is still
# the same across all replicas if `quorum_id` changed.
#
# We can't guarantee a majority of replicas in this new quorum
# has the latest `max_step`.
#
# TODO: This guarantee is currently lacking
# in torchft unless `shrink_only` is set.
self._local_step += 1

if self._local_step == self._sync_every - self._fragment_sync_delay:
# Time to prepare a fragment
#
# After the quorum though, everyone will have the same
# `local_step` because replicas with the chosen
# `max_step` will have the same `local_step`. That is
# because we don't take additional steps after commit.
# Some replicas will get the same copy of the model, implying batches
# can be overrepresented.
self._manager.start_quorum()
fragment = self._current_fragment()
self._fragments[fragment].prepare_sync()

if self._local_step < self._sync_every:
return

if self._local_step == self._sync_every:
# Time to sync a fragment
fragment = self._current_fragment()
self._fragments[fragment].perform_sync()

# If the allreduce truly failed, we'll keep retrying this fragment.
# We reset the parameters upon failure. We'll skip over some data
# but we won't over-train before syncing.

self._local_step = 0
return

assert (
False
), f"{self._local_step=} should never be greater than {self._sync_every=}"
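A standalone sketch of the resulting schedule (hypothetical values, not the torchft implementation): each fragment's allreduce is prepared fragment_sync_delay steps before it is waited on, the fragment index rotates with the manager's step, and the local step counter resets after every sync.

# Hypothetical trace of the per-step flow above; assumes the manager's step
# advances once per successful sync (stand-in for manager.current_step()).
num_fragments = 2
sync_every = 4            # per-fragment value, i.e. user sync_every // num_fragments
fragment_sync_delay = 1

local_step = 0
manager_step = 0

for _ in range(16):                                   # 16 optimizer steps
    local_step += 1

    if local_step == sync_every - fragment_sync_delay:
        fragment = manager_step % num_fragments       # _current_fragment()
        print(f"manager step {manager_step}: prepare (allreduce) fragment {fragment}")

    if local_step == sync_every:
        fragment = manager_step % num_fragments
        print(f"manager step {manager_step}: sync (outer step) fragment {fragment}")
        manager_step += 1                             # assumed bump after commit
        local_step = 0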