simplify streaming diloco #233
Changes from all commits
@@ -281,8 +281,8 @@ def _save_grads(self) -> None:
                local_param = p.to_local()
            else:
                local_param = p
-           pseudogradient = local_param - self.original_parameters[name].to(
-               p.device
+           pseudogradient = (
+               self.original_parameters[name].to(p.device) - local_param
            )
            self._grads[name] = pseudogradient
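To see why the sign matters, here is a minimal sketch (not part of the PR; the tensors and the SGD outer optimizer are illustrative stand-ins) of how an outer optimizer consumes the pseudogradient:

```python
import torch

# Illustrative stand-ins, not torchft internals.
original = torch.tensor([1.0, 2.0])  # global params saved before the inner steps
local = torch.tensor([0.9, 1.8])     # params after a few inner steps (moved downhill)

old_pseudograd = local - original    # tensor([-0.1000, -0.2000]) -- points uphill for an SGD-style update
new_pseudograd = original - local    # tensor([ 0.1000,  0.2000]) -- points along the descent direction

# An SGD outer optimizer applies param <- param - lr * grad, so only the
# flipped sign moves the global params toward the locally improved params.
param = original.clone().requires_grad_(True)
outer_opt = torch.optim.SGD([param], lr=1.0)
param.grad = new_pseudograd
outer_opt.step()
print(param)  # tensor([0.9000, 1.8000], requires_grad=True) -- matches the local params
```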
@@ -318,7 +318,7 @@ def _merge_parameters(self) -> None:
        Merges the local and global parameters.
        """
        for name, p in self._model_fragment.named_parameters():
-           p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha)
+           p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)

    @torch.profiler.record_function("torchft::local_sgd::wait")
    def wait(self) -> None:
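Two things change on this line: `lerp` (out-of-place, so the merge never actually landed in `p.data`) becomes `lerp_` (in-place), and the weight flips from `1 - alpha` to `alpha`. A small sketch of the `Tensor.lerp_` semantics, with made-up values:

```python
import torch

global_param = torch.tensor([0.0, 0.0])  # stand-in for p.data after the outer step
local_param = torch.tensor([1.0, 1.0])   # stand-in for self._local_parameters[name]
alpha = 0.25                             # stand-in for self._fragment_update_alpha

# Out-of-place lerp returns a new tensor; global_param is left untouched.
_ = global_param.lerp(local_param, alpha)
print(global_param)                      # tensor([0., 0.])

# In-place lerp_ writes back start + weight * (end - start),
# i.e. keep (1 - alpha) of the global params and take alpha of the local params.
global_param.lerp_(local_param, alpha)
print(global_param)                      # tensor([0.2500, 0.2500])
```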
@@ -335,20 +335,6 @@ def wait(self) -> None:

        self._allreduce_futures = []

-   def should_prepare_fragment(self, step: int) -> bool:
-       """
-       Determines if the fragment should be asynchronously sent to other replicas
-       """
-       step_to_prepare = step - self._fragment_sync_offset
-       return step_to_prepare % self._sync_every == 0
-
-   def should_sync_fragment(self, step: int) -> bool:
-       """
-       Determines if the fragment should be synchronized with other replicas
-       """
-       step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
-       return step_to_sync % self._sync_every == 0
-
    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
    def prepare_sync(self) -> None:
        """
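For reference, a standalone re-implementation (with made-up numbers) of what the two removed predicates computed; this is the per-step modular arithmetic that the simplified schedule later in the diff replaces:

```python
# Made-up values; the self._ attributes from the removed methods become plain variables.
sync_every = 10           # per-fragment sync interval
fragment_sync_offset = 5  # this fragment's offset into the schedule
fragment_sync_delay = 2   # steps between sending the allreduce and waiting on it

def should_prepare_fragment(step: int) -> bool:
    return (step - fragment_sync_offset) % sync_every == 0

def should_sync_fragment(step: int) -> bool:
    return (step - fragment_sync_offset - fragment_sync_delay) % sync_every == 0

print([s for s in range(1, 31) if should_prepare_fragment(s)])  # [5, 15, 25]
print([s for s in range(1, 31) if should_sync_fragment(s)])     # [7, 17, 27]
```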
@@ -384,27 +370,6 @@ def perform_sync(self) -> bool:
        steps using the outer optimizer.
        """
-       # Waiting for an allreduce before it has been sent is currently not supported.
-       # Please make sure to not do this to avoid running into inconsistencies.
-       #
-       # This can happen when using large values of `fragment_sync_delay`.
-       # The node might not have participated in syncing of this fragment.
-       #
-       # The allreduce for other nodes who did might actually
-       # succeed and in that case, we shouldn't allow recovery
-       # from this node.
-       #
-       # We do need to increase the `max_step` here so we
-       # don't end up in an infinite loop of needing to recover
-       # but we can't let other nodes recover from this node
-       # because it doesn't have the latest state.
-       #
-       # We can add a `is_catching_up` flag to the state_dict
-       # to disallow recoveries from this node. Such nodes can
-       # be excluded from `max_step` calculation unless all
-       # nodes are catching up. This approach makes the replica state
-       # of global parameters diverge though. So we could add recovery
-       # for a particular fragment from a peer node as a part of the
-       # `should_commit` or next `quorum` when a node is catching up.
        assert len(self._allreduce_futures) > 0

        self.wait()
@@ -588,7 +553,11 @@ def __init__(
        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be syncrhonized at a time")

-       if fragment_sync_delay >= sync_every:
+       if sync_every % len(model_fragments) != 0:
+           raise ValueError("sync_every must divide the number of fragments")
+
+       self._sync_every: int = sync_every // len(model_fragments)
+       if fragment_sync_delay >= self._sync_every:
            raise ValueError(
                "Fragment must be synced before it is reduced another time"
            )

Review comments on the new divisibility check:
- This makes sense for now -- we can relax this later if it turns out people want to sync different parts of the model at different rates, though that has other significant considerations.
- Reply: we can still sync at different rates by passing in a different
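A quick worked example of the new constructor checks (the numbers are illustrative):

```python
# Illustrative numbers for the new validation logic.
sync_every = 20
num_fragments = 4        # len(model_fragments)
fragment_sync_delay = 2

assert sync_every >= num_fragments       # only one fragment is synchronized at a time
assert sync_every % num_fragments == 0   # the fragments must divide sync_every evenly

per_fragment_sync_every = sync_every // num_fragments  # 5: each fragment syncs every 5 steps
assert fragment_sync_delay < per_fragment_sync_every   # a fragment must finish syncing before it is reduced again
```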
@@ -599,23 +568,12 @@ def __init__(
        super().__init__()
        self._manager = manager

-       # Protects `_local_step`
-       self._lock = threading.Lock()
-
        # The number of training iterations performed.
        # Used to synchronize which fragment to send across all
        # replicas
        self._local_step = 0

-       # Sync `_local_step` with other replicas
-       self._manager.register_state_dict_fn(
-           "local_step",
-           self._load_step,
-           lambda: self._local_step,
-       )
-
        # Used to perform quorum before any training happens
        self._should_recover = True
        self._fragment_sync_delay = fragment_sync_delay

        self._hooks: List[RemovableHandle] = []
@@ -648,16 +606,9 @@ def __init__(
        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
        assert fragment_sync_delay < sync_every // len(model_fragments)

-       # Used to ensure that we try to sync a fragment after we've sent a prepare for it
-       self._first_prepare_sent: set[int] = set()
-
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-   def _load_step(self, step: int) -> None:
-       with self._lock:
-           self._local_step = step
-
    def _save_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.save_parameters()
@@ -694,32 +645,19 @@ def _wait(self) -> None:
        for fragment in self._fragments:
            fragment.wait()

-       self._first_prepare_sent.clear()
-
-   def _quorum_loop(self) -> None:
+   def _current_fragment(self) -> int:
        """
-       Performs infinite retries until quorum is successfull
+       Determines which fragment to prepare/sync based on the current step.
        """
-       while True:
-           self._manager.start_quorum()
-
-           if self._manager.errored() is None:
-               return
+       step = self._manager.current_step()
+       return step % len(self._fragments)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
-       if self._should_recover:
-           # Get the correct step when. This will continue after other committed.
-           self._quorum_loop()
-           self._should_recover = False
-           # This is to be consistent with the nodes that are not recovering. They
-           # proceed with the code below on the step after quorum completes.
-           return
-
        # We need to make sure all nodes send the same fragments in order.
        # This is to avoid deadlocking e.g.
        #
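A small sketch of the round-robin effect of the new `_current_fragment` helper (the step values are made up; in the PR they come from `manager.current_step()`):

```python
# With 3 fragments, the manager's step selects fragments round-robin.
num_fragments = 3

def current_fragment(manager_step: int) -> int:
    return manager_step % num_fragments

for step in range(6):
    print(step, current_fragment(step))
# 0 0
# 1 1
# 2 2
# 3 0
# 4 1
# 5 2
```

Because every replica reads the same step from the manager, they all pick the same fragment to allreduce, which is what the deadlock comment above is guarding against.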
@@ -730,91 +668,32 @@ def _step_post_hook(
        #
        # Both of them will fail because Node A didn't send fragment 2
        # and Node B didn't send fragment 1.
-       with self._lock:
-           self._local_step += 1
-           step = self._local_step
-
-       # Start sending fragments
-       for i, fragment in enumerate(self._fragments):
-           if not fragment.should_prepare_fragment(step):
-               continue
-
-           logger.debug(f"preparing fragment {i} at step {step}")
-
-           self._first_prepare_sent.add(i)
-           fragment.prepare_sync()
-
-       for i, fragment in enumerate(self._fragments):
-           if not fragment.should_sync_fragment(step):
-               continue
-
-           # We need to have sent an allreduce before we can syncing
-           # a fragment
-           if i not in self._first_prepare_sent:
-               continue
-
-           logger.debug(f"syncing fragment {i} at step {step}")
-
-           if not fragment.perform_sync():
-               # Cancel all the previously scheduled allreduce by simply
-               # waiting for them. They should have failed but lets be
-               # paranoid anyway.
-               #
-               # We could choose to resend the failed fragments but that is
-               # more complicated since it involves coordinating all nodes to
-               # rewind and resend the fragments.
-               self._wait()
-
-               # Reset the local step. This is needed in case manager `should_commit` fails.
-               #
-               # This is because there can be a node that has the same `max_step` as the
-               # nodes that reached the commit point. However, this node failed before
-               # it could reach the commit point. So the local steps for these two nodes
-               # are not the same. But either can be used for recovery.
-               #
-               # To make sure both return the same step, we just reset the step to 0
-               # and start from scratch.
-               #
-               # In the happy path, we don't need to reset the step because --
-               # Nodes participating in the commit bumped their `max_step`.
-               # Any new nodes will take `local_step` from one of these nodes, which must
-               # be the same across all nodes because they took the same number of steps
-               # since the last commit to get to the most recent commit.
-               with self._lock:
-                   self._local_step = 0
-
-       # Avoid doing allreduce after quorum failed.
-       #
-       # Maybe a different quorum formed without this node, so this node
-       # will incorrectly try to allreduce potentially on an incorrect
-       # fragment because the local_step is also out of sync.
-       # The replica will need recovery later anyway.
-       #
-       # So in case it didn't crash (e.g. network errors), we can save some
-       # training data by looping here. Otherwise that training data goes to
-       # waste after recovery
-       self._quorum_loop()
-
-       # TODO: Since we do quorum after commit, there might be a big gap until
-       # the next allreduce. This increases the chances of nodes failing
-       # and so the allreduce to fail.
-       # - We could maybe do a quorum again right before preparing for a fragment
-       #   using `shrink_only`. This might make it tricky for new nodes to join
-       #   though.
-       # - Maintain a sequence number in the state dict that gets bumped at every
-       #   quorum call. Then we can do a quorum right before allreduce and avoid
-       #   doing quorums after commit.
+       # We need to set make sure `_local_step` is still
+       # the same across all replicas if `quorum_id` changed.
+       #
+       # We can't garuntee a majority of replicas in this new quorum
+       # has the latest `max_step`.
+       #
+       # TODO: This is garuntee is currently lacking
+       # in torchft unless `shrink_only` is set.
+       self._local_step += 1
+
+       if self._local_step == self._sync_every - self._fragment_sync_delay:
+           # Time to prepare a fragment
+           #
+           # After the quorum though, everyone will have the same
+           # `local_step` because replicas with the chosen
+           # `max_step` will have the same `local_step`. That is
+           # because we don't take additional steps after commit.
+           # Some replicas will get the same copy of the model, implying batches
+           # can be overrepresented.
+           self._manager.start_quorum()
+           fragment = self._current_fragment()
+           self._fragments[fragment].prepare_sync()
+
+       if self._local_step < self._sync_every:
+           return
+
+       if self._local_step == self._sync_every:
+           # Time to sync a fragment
+           fragment = self._current_fragment()
+           self._fragments[fragment].perform_sync()
+
+           # If the allreduce truly failed, we'll keep retrying this fragment.
+           # We reset the parameters upon failure. We'll skip over some data
+           # but we won't over train before syncing.
+
+           self._local_step = 0
+           return
+
+       assert (
+           False
+       ), f"{self._local_step=} should never be greater than {self._sync_every=}"
Review thread on the pseudogradient sign change:

Comment: This is flipped because we don't do `1 -` below?

Reply: what do you mean by `1 -`? the outer optimizer will do `param = param - pseudo_grad`, but loss goes down in the direction of `-pseudo_grad`. this is pretty much why we were seeing the loss going up earlier, i think.

Reply: proof of this:
`pseudo_grad = new_param - old_param`, where `new_param = old_param - grad`, so `pseudo_grad = -grad`.
The outer step then gives `old_param - pseudo_grad = old_param + grad`, which moves up the gradient.
With the flipped sign, `pseudo_grad = old_param - new_param = grad`, and the outer step gives `old_param - pseudo_grad = old_param - grad`, as intended.
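A tiny numerical check of the argument above (a sketch with a single scalar parameter and an SGD-style outer update; the values are made up):

```python
old_param = 1.0
grad = 0.25
new_param = old_param - grad      # inner step moves downhill: 0.75

# old sign: pseudo_grad = new_param - old_param = -grad
pseudo_grad = new_param - old_param
print(old_param - pseudo_grad)    # 1.25 -> the outer step moves *up* the gradient

# flipped sign: pseudo_grad = old_param - new_param = grad
pseudo_grad = old_param - new_param
print(old_param - pseudo_grad)    # 0.75 -> the outer step recovers the inner update
```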