
Commit 6069b95

simplify streaming diloco
Summary:
- Since we made the simplifying assumption that there is only ever one in-flight fragment, we can simplify some of the logic, in particular getting rid of the local step in the manager state. We now just use the manager's step to determine which fragment to sync.
- This also lets us support heterogeneous hardware by tuning the sync_every setting, so that slower/faster machines perform fewer/more local steps before they sync.
- We can also perform quorum right before preparing a fragment sync, which ensures that all replicas have the same max step and sync the same fragment.
- Fix some numeric issues: the sign of the pseudogradient, and the in-place lerp when mixing local and global parameters.
1 parent c14ee65 commit 6069b95
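
As a rough sketch of the schedule described above (illustrative only; the standalone helper `actions_for_step` and its argument names are mine, not torchft API, and the real logic lives in `_step_post_hook` in the diff below):

    # Minimal sketch, assuming `sync_every` has already been divided by the
    # number of fragments, as the constructor in this commit does.
    def actions_for_step(local_step: int, manager_step: int, num_fragments: int,
                         sync_every: int, fragment_sync_delay: int) -> list[str]:
        fragment = manager_step % num_fragments  # every replica picks the same fragment
        actions = []
        if local_step == sync_every - fragment_sync_delay:
            actions.append(f"prepare fragment {fragment}")  # quorum, then async allreduce
        if local_step == sync_every:
            actions.append(f"sync fragment {fragment}")     # wait, outer step, reset counter
        return actions

For example, with a per-fragment sync_every of 5 and fragment_sync_delay of 1 (hypothetical values), steps 1-3 are purely local, step 4 prepares the allreduce for the current fragment, and step 5 waits for it and applies the outer optimizer before the counter resets.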

File tree

3 files changed: +50 -196 lines


torchft/local_sgd.py

Lines changed: 44 additions & 162 deletions
@@ -281,8 +281,8 @@ def _save_grads(self) -> None:
                     local_param = p.to_local()
                 else:
                     local_param = p
-                pseudogradient = local_param - self.original_parameters[name].to(
-                    p.device
+                pseudogradient = (
+                    self.original_parameters[name].to(p.device) - local_param
                 )
                 self._grads[name] = pseudogradient

@@ -318,7 +318,7 @@ def _merge_parameters(self) -> None:
        Merges the local and global parameters.
        """
        for name, p in self._model_fragment.named_parameters():
-            p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha)
+            p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)

    @torch.profiler.record_function("torchft::local_sgd::wait")
    def wait(self) -> None:
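
The two numeric fixes in the hunks above can be sanity-checked in isolation (a toy example, not torchft code; the values are made up):

    import torch

    # Sign of the pseudogradient: with the outer optimizer doing plain SGD with
    # lr=1.0, the pseudogradient must point from the local weights back toward
    # the original ones so that the outer step moves toward the local update.
    original, local = 1.0, 3.0       # local training moved the weight by +2
    pseudograd = original - local    # -2.0, the convention after this commit
    assert original - 1.0 * pseudograd == local

    # In-place lerp: lerp_ mutates p, while the old out-of-place lerp returned
    # a new tensor that was discarded, so the merge silently did nothing.
    p = torch.tensor([0.0])
    p.lerp_(torch.tensor([1.0]), 0.25)   # p <- p + 0.25 * (1.0 - p)
    assert torch.allclose(p, torch.tensor([0.25]))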
@@ -335,20 +335,6 @@ def wait(self) -> None:

        self._allreduce_futures = []

-    def should_prepare_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be asynchronously sent to other replicas
-        """
-        step_to_prepare = step - self._fragment_sync_offset
-        return step_to_prepare % self._sync_every == 0
-
-    def should_sync_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be synchronized with other replicas
-        """
-        step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
-        return step_to_sync % self._sync_every == 0
-
    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
    def prepare_sync(self) -> None:
        """
@@ -384,27 +370,6 @@ def perform_sync(self) -> bool:
        steps using the outer optimizer.
        """
        # Waiting for an allreduce before it has been sent is currently not supported.
-        # Please make sure to not do this to avoid running into inconsistencies.
-        #
-        # This can happen when using large values of `fragment_sync_delay`.
-        # The node might not have participated in syncing of this fragment.
-        #
-        # The allreduce for other nodes who did might actually
-        # succeed and in that case, we shouldn't allow recovery
-        # from this node.
-        #
-        # We do need to increase the `max_step` here so we
-        # don't end up in an infinite loop of needing to recover
-        # but we can't let other nodes recover from this node
-        # because it doesn't have the latest state.
-        #
-        # We can add a `is_catching_up` flag to the state_dict
-        # to disallow recoveries from this node. Such nodes can
-        # be excluded from `max_step` calculation unless all
-        # nodes are catching up. This approach makes the replica state
-        # of global parameters diverge though. So we could add recovery
-        # for a particular fragment from a peer node as a part of the
-        # `should_commit` or next `quorum` when a node is catching up.
        assert len(self._allreduce_futures) > 0

        self.wait()
@@ -588,7 +553,11 @@ def __init__(
        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be syncrhonized at a time")

-        if fragment_sync_delay >= sync_every:
+        if sync_every % len(model_fragments) != 0:
+            raise ValueError("sync_every must divide the number of fragments")
+
+        self._sync_every: int = sync_every // len(model_fragments)
+        if fragment_sync_delay >= self._sync_every:
            raise ValueError(
                "Fragment must be synced before it is reduced another time"
            )
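
A quick check of the new constructor arithmetic with hypothetical values:

    # Hypothetical values: 3 fragments, sync_every=15, fragment_sync_delay=2.
    num_fragments, sync_every, fragment_sync_delay = 3, 15, 2
    assert sync_every % num_fragments == 0                  # passes the divisibility check
    per_fragment_sync_every = sync_every // num_fragments   # 5 local steps per fragment cycle
    assert fragment_sync_delay < per_fragment_sync_every    # prepare 2 steps before the sync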
@@ -599,23 +568,12 @@ def __init__(
        super().__init__()
        self._manager = manager

-        # Protects `_local_step`
-        self._lock = threading.Lock()
-
        # The number of training iterations performed.
        # Used to synchronize which fragment to send across all
        # replicas
        self._local_step = 0

-        # Sync `_local_step` with other replicas
-        self._manager.register_state_dict_fn(
-            "local_step",
-            self._load_step,
-            lambda: self._local_step,
-        )
-
-        # Used to perform quorum before any training happens
-        self._should_recover = True
+        self._fragment_sync_delay = fragment_sync_delay

        self._hooks: List[RemovableHandle] = []

@@ -648,16 +606,9 @@ def __init__(
        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
        assert fragment_sync_delay < sync_every // len(model_fragments)

-        # Used to ensure that we try to sync a fragment after we've sent a prepare for it
-        self._first_prepare_sent: set[int] = set()
-
        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()

-    def _load_step(self, step: int) -> None:
-        with self._lock:
-            self._local_step = step
-
    def _save_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.save_parameters()
@@ -694,32 +645,19 @@ def _wait(self) -> None:
        for fragment in self._fragments:
            fragment.wait()

-        self._first_prepare_sent.clear()
-
-    def _quorum_loop(self) -> None:
+    def _current_fragment(self) -> int:
        """
-        Performs infinite retries until quorum is successfull
+        Determines which fragment to prepare/sync based on the current step.
        """
-        while True:
-            self._manager.start_quorum()
-
-            if self._manager.errored() is None:
-                return
+        step = self._manager.current_step()
+        return step % len(self._fragments)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
-        if self._should_recover:
-            # Get the correct step when. This will continue after other committed.
-            self._quorum_loop()
-            self._should_recover = False
-            # This is to be consistent with the nodes that are not recovering. They
-            # proceed with the code below on the step after quorum completes.
-            return
-
        # We need to make sure all nodes send the same fragments in order.
        # This is to avoid deadlocking e.g.
        #
@@ -730,91 +668,35 @@ def _step_post_hook(
        #
        # Both of them will fail because Node A didn't send fragment 2
        # and Node B didn't send fragment 1.
-        with self._lock:
-            self._local_step += 1
-            step = self._local_step
-
-        # Start sending fragments
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_prepare_fragment(step):
-                continue
-
-            logger.debug(f"preparing fragment {i} at step {step}")
-
-            self._first_prepare_sent.add(i)
-            fragment.prepare_sync()
-
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_sync_fragment(step):
-                continue
-
-            # We need to have sent an allreduce before we can syncing
-            # a fragment
-            if i not in self._first_prepare_sent:
-                continue
-
-            logger.debug(f"syncing fragment {i} at step {step}")
-
-            if not fragment.perform_sync():
-                # Cancel all the previously scheduled allreduce by simply
-                # waiting for them. They should have failed but lets be
-                # paranoid anyway.
-                #
-                # We could choose to resend the failed fragments but that is
-                # more complicated since it involves coordinating all nodes to
-                # rewind and resend the fragments.
-                self._wait()
-
-                # Reset the local step. This is needed in case manager `should_commit` fails.
-                #
-                # This is because there can be a node that has the same `max_step` as the
-                # nodes that reached the commit point. However, this node failed before
-                # it could reach the commit point. So the local steps for these two nodes
-                # are not the same. But either can be used for recovery.
-                #
-                # To make sure both return the same step, we just reset the step to 0
-                # and start from scratch.
-                #
-                # In the happy path, we don't need to reset the step because --
-                # Nodes participating in the commit bumped their `max_step`.
-                # Any new nodes will take `local_step` from one of these nodes, which must
-                # be the same across all nodes because they took the same number of steps
-                # since the last commit to get to the most recent commit.
-                with self._lock:
-                    self._local_step = 0
-
-                # Avoid doing allreduce after quorum failed.
-                #
-                # Maybe a different quorum formed without this node, so this node
-                # will incorrectly try to allreduce potentially on an incorrect
-                # fragment because the local_step is also out of sync.
-                # The replica will need recovery later anyway.
-                #
-                # So in case it didn't crash (e.g. network errors), we can save some
-                # training data by looping here. Otherwise that training data goes to
-                # waste after recovery
-                self._quorum_loop()
-
-        # TODO: Since we do quorum after commit, there might be a big gap until
-        # the next allreduce. This increases the chances of nodes failing
-        # and so the allreduce to fail.
-        # - We could maybe do a quorum again right before preparing for a fragment
-        # using `shrink_only`. This might make it tricky for new nodes to join
-        # though.
-        # - Maintain a sequence number in the state dict that gets bumped at every
-        # quorum call. Then we can do a quorum right before allreduce and avoid
-        # doing quorums after commit.
-
-        # We need to set make sure `_local_step` is still
-        # the same across all replicas if `quorum_id` changed.
-        #
-        # We can't garuntee a majority of replicas in this new quorum
-        # has the latest `max_step`.
-        #
-        # TODO: This is garuntee is currently lacking
-        # in torchft unless `shrink_only` is set.
+        self._local_step += 1
+
+        if self._local_step == self._sync_every - self._fragment_sync_delay:
+            # Time to prepare a fragment
            #
-        # After the quorum though, everyone will have the same
-        # `local_step` because replicas with the chosen
-        # `max_step` will have the same `local_step`. That is
-        # because we don't take additional steps after commit.
+            # Some replicas will get the same copy of the model, implying batches
+            # can be overrepresented.
+            self._manager.start_quorum()
+            fragment = self._current_fragment()
+            print("<DEBUG>")
+            print(fragment)
+            print(len(self._fragments))
+            self._fragments[fragment].prepare_sync()
+
+        if self._local_step < self._sync_every:
+            return
+
+        if self._local_step == self._sync_every:
+            # Time to sync a fragment
+            fragment = self._current_fragment()
+            self._fragments[fragment].perform_sync()
+
+            # If the allreduce truly failed, we'll keep retrying this fragment.
+            # We reset the parameters upon failure. We'll skip over some data
+            # but we won't over train before syncing.
+
+            self._local_step = 0
+            return
+
+        assert (
+            False
+        ), f"{self._local_step=} should never be greater than {self._sync_every=}"
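
To make the fragment rotation concrete, here is a hypothetical trace of the hook (values are made up, and it assumes the manager's step advances by one after each successful sync):

    # 2 fragments, per-fragment sync_every=3, fragment_sync_delay=1
    sync_every, delay, num_fragments = 3, 1, 2
    manager_step = 0
    for step in range(1, 7):                      # six optimizer steps
        local_step = (step - 1) % sync_every + 1  # the hook's counter, 1..sync_every
        fragment = manager_step % num_fragments
        if local_step == sync_every - delay:
            print(f"step {step}: prepare fragment {fragment}")
        if local_step == sync_every:
            print(f"step {step}: sync fragment {fragment}")
            manager_step += 1                     # next cycle rotates to the other fragment
    # prints: prepare 0 at step 2, sync 0 at step 3,
    #         prepare 1 at step 5, sync 1 at step 6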

torchft/local_sgd_integ_test.py

Lines changed: 0 additions & 23 deletions
@@ -418,16 +418,6 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:

        assert_equal_global_state(rep1, rep0)

-        for step in rep1.keys():
-            if step == 2:
-                # Replica 0 should have reset its `local_step` after failure
-                self.assertEqual(rep1[step]["user"]["local_step"], 0)
-                self.assertEqual(rep0[step]["user"]["local_step"], 5)
-            else:
-                self.assertEqual(
-                    rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-                )
-
        self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)

CONFIG: list[tuple[bool, int, int, float]] = [
@@ -509,14 +499,6 @@ def test_streaming_diloco_upscale(
        assert_equal_global_state(rep0, rep1)
        assert_equal_global_state(rep0, rep2)

-        for step in rep0.keys():
-            self.assertEqual(
-                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-            )
-            self.assertEqual(
-                rep1[step]["user"]["local_step"], rep2[step]["user"]["local_step"]
-            )
-
        for event_injector in event_injectors:
            self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)

@@ -586,11 +568,6 @@ def test_streaming_diloco_commit_failure(

        assert_equal_global_state(rep0, rep1)

-        for step in rep0.keys():
-            self.assertEqual(
-                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-            )
-
        for event_injector in event_injectors:
            self.assertEqual(
                event_injector.count[EventInjectorEvent.AllreduceFailure], 1

torchft/local_sgd_test.py

Lines changed: 6 additions & 11 deletions
@@ -157,20 +157,15 @@ def test_diloco_healthy(self) -> None:
            loss.backward()
            inner_optimizer.step()

-            self.assertEqual(diloco._local_step, 0)
-            loss = model(inp).mean()
-            loss.backward()
-            inner_optimizer.step()
-
            self.assertEqual(diloco._local_step, 1)
-            self.assertEqual(manager.start_quorum.call_count, 1)
+            manager.current_step.return_value = 0
+            manager.should_commit.return_value = True
            loss = model(inp).mean()
            loss.backward()
            inner_optimizer.step()
-            self.assertEqual(manager.start_quorum.call_count, 2)

-            manager.should_commit.return_value = True
-            self.assertEqual(diloco._local_step, 2)
+            self.assertEqual(diloco._local_step, 0)
+            self.assertEqual(manager.start_quorum.call_count, 1)
            torch.testing.assert_close(
                diloco._fragments[0].original_parameters, _params_dict(model)
            )
@@ -320,8 +315,8 @@ def fake_allreduce(
        diloco._fragments[0]._set_grads()

        # we added 2 to the parameters, then multiplied the gradients by 2
-        # so we should expect the model's gradient to be 4
-        expected_grad = 4
+        # so we should expect the model's gradient to be -4
+        expected_grad = -4
        for param in model.parameters():
            assert param.grad is not None
            t = torch.empty_like(param.grad)
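
For reference, the arithmetic behind the updated expectation (assuming, as the comment above says, that the test shifts the parameters by +2 and its fake allreduce doubles each gradient):

    # pseudogradient = original - local = original - (original + 2) = -2
    # doubled by the fake allreduce: -2 * 2 == -4
    expected_grad = (0 - 2) * 2
    assert expected_grad == -4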
