@@ -281,8 +281,8 @@ def _save_grads(self) -> None:
                     local_param = p.to_local()
                 else:
                     local_param = p
-                pseudogradient = local_param - self.original_parameters[name].to(
-                    p.device
+                pseudogradient = (
+                    self.original_parameters[name].to(p.device) - local_param
                 )
                 self._grads[name] = pseudogradient
 
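The sign flip above matters because the pseudogradient is later handed to the outer optimizer as if it were a gradient. A minimal sketch of the convention, assuming a plain SGD-style outer step; the names `original`, `local`, and `lr` are illustrative and not from the patch:

import torch

# With pseudogradient = original - local, a vanilla outer step
# `param -= lr * pseudogradient` moves the global weights toward the
# locally trained weights, which is the intended direction.
original = torch.tensor([1.0, 1.0])    # snapshot saved before the inner steps
local = torch.tensor([0.0, 2.0])       # weights after the inner optimizer ran
pseudogradient = original - local      # new sign convention from the hunk above
lr = 1.0
updated = original - lr * pseudogradient
assert torch.allclose(updated, local)  # a full step recovers the local weights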
@@ -318,7 +318,7 @@ def _merge_parameters(self) -> None:
         Merges the local and global parameters.
         """
         for name, p in self._model_fragment.named_parameters():
-            p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha)
+            p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
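Two things change in this one-liner: `lerp` becomes its in-place variant `lerp_` (the out-of-place call computed a result and discarded it, so the merge was a no-op), and the weight is passed as `self._fragment_update_alpha` directly, since `Tensor.lerp_(end, weight)` computes `self + weight * (end - self)`. A minimal sketch of those semantics with illustrative values:

import torch

p = torch.tensor([0.0, 0.0])       # global parameters
local = torch.tensor([1.0, 1.0])   # saved local parameters
alpha = 0.25

p.lerp_(local, alpha)              # in-place: p = p + alpha * (local - p)
assert torch.allclose(p, torch.tensor([0.25, 0.25]))

# The out-of-place `p.lerp(local, alpha)` returns a new tensor and leaves
# `p` untouched, which is why the original line had no effect.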
@@ -335,20 +335,6 @@ def wait(self) -> None:
 
         self._allreduce_futures = []
 
-    def should_prepare_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be asynchronously sent to other replicas
-        """
-        step_to_prepare = step - self._fragment_sync_offset
-        return step_to_prepare % self._sync_every == 0
-
-    def should_sync_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be synchronized with other replicas
-        """
-        step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
-        return step_to_sync % self._sync_every == 0
-
     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
@@ -384,27 +370,6 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        # Please make sure to not do this to avoid running into inconsistencies.
-        #
-        # This can happen when using large values of `fragment_sync_delay`.
-        # The node might not have participated in syncing of this fragment.
-        #
-        # The allreduce for other nodes who did might actually
-        # succeed and in that case, we shouldn't allow recovery
-        # from this node.
-        #
-        # We do need to increase the `max_step` here so we
-        # don't end up in an infinite loop of needing to recover
-        # but we can't let other nodes recover from this node
-        # because it doesn't have the latest state.
-        #
-        # We can add a `is_catching_up` flag to the state_dict
-        # to disallow recoveries from this node. Such nodes can
-        # be excluded from `max_step` calculation unless all
-        # nodes are catching up. This approach makes the replica state
-        # of global parameters diverge though. So we could add recovery
-        # for a particular fragment from a peer node as a part of the
-        # `should_commit` or next `quorum` when a node is catching up.
         assert len(self._allreduce_futures) > 0
 
         self.wait()
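With the long recovery discussion removed, the remaining contract is simply that `perform_sync()` may only run after `prepare_sync()` has queued at least one allreduce, which the assert enforces. A rough usage sketch under that assumption; the `fragment` object and the interleaving comment are illustrative:

def sync_fragment(fragment) -> None:
    # Schedule the pseudogradient allreduce first...
    fragment.prepare_sync()
    # ...optionally run `fragment_sync_delay` more local steps here...
    # ...then wait on the futures and apply the outer optimizer step.
    fragment.perform_sync()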
@@ -588,7 +553,11 @@ def __init__(
         if sync_every < len(model_fragments):
             raise ValueError("Only 1 fragment can be synchronized at a time")
 
-        if fragment_sync_delay >= sync_every:
+        if sync_every % len(model_fragments) != 0:
+            raise ValueError("sync_every must be divisible by the number of fragments")
+
+        self._sync_every: int = sync_every // len(model_fragments)
+        if fragment_sync_delay >= self._sync_every:
             raise ValueError(
                 "Fragment must be synced before it is reduced another time"
             )
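The new checks pin down the schedule: `sync_every` is the whole-model period, and each fragment gets an equal share of it. A worked example with illustrative numbers:

sync_every, num_fragments, fragment_sync_delay = 6, 3, 1

assert sync_every >= num_fragments        # at most one fragment syncs per step
assert sync_every % num_fragments == 0    # per-fragment period must be an integer
per_fragment_sync_every = sync_every // num_fragments  # == 2, stored as _sync_every
assert fragment_sync_delay < per_fragment_sync_every   # delay fits inside the period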
@@ -599,23 +568,12 @@ def __init__(
         super().__init__()
         self._manager = manager
 
-        # Protects `_local_step`
-        self._lock = threading.Lock()
-
         # The number of training iterations performed.
         # Used to synchronize which fragment to send across all
         # replicas
         self._local_step = 0
 
-        # Sync `_local_step` with other replicas
-        self._manager.register_state_dict_fn(
-            "local_step",
-            self._load_step,
-            lambda: self._local_step,
-        )
-
-        # Used to perform quorum before any training happens
-        self._should_recover = True
+        self._fragment_sync_delay = fragment_sync_delay
 
         self._hooks: List[RemovableHandle] = []
 
@@ -648,16 +606,9 @@ def __init__(
         # `_StreamingDiLoCoFragment` about the fragment sync schedule.
         assert fragment_sync_delay < sync_every // len(model_fragments)
 
-        # Used to ensure that we try to sync a fragment after we've sent a prepare for it
-        self._first_prepare_sent: set[int] = set()
-
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-    def _load_step(self, step: int) -> None:
-        with self._lock:
-            self._local_step = step
-
     def _save_parameters(self) -> None:
         for fragment in self._fragments:
             fragment.save_parameters()
@@ -694,32 +645,19 @@ def _wait(self) -> None:
         for fragment in self._fragments:
             fragment.wait()
 
-        self._first_prepare_sent.clear()
-
-    def _quorum_loop(self) -> None:
+    def _current_fragment(self) -> int:
         """
-        Performs infinite retries until quorum is successfull
+        Determines which fragment to prepare/sync based on the current step.
         """
-        while True:
-            self._manager.start_quorum()
-
-            if self._manager.errored() is None:
-                return
+        step = self._manager.current_step()
+        return step % len(self._fragments)
 
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
-        if self._should_recover:
-            # Get the correct step when. This will continue after other committed.
-            self._quorum_loop()
-            self._should_recover = False
-            # This is to be consistent with the nodes that are not recovering. They
-            # proceed with the code below on the step after quorum completes.
-            return
-
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
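`_current_fragment` derives the fragment index from the manager's global step instead of a locally tracked counter, so any replica that joins the same quorum lands on the same fragment. A small round-robin sketch with illustrative values:

num_fragments = 3
schedule = [step % num_fragments for step in range(6)]
assert schedule == [0, 1, 2, 0, 1, 2]  # every replica derives the same order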
@@ -730,91 +668,35 @@ def _step_post_hook(
         #
         # Both of them will fail because Node A didn't send fragment 2
         # and Node B didn't send fragment 1.
-        with self._lock:
-            self._local_step += 1
-            step = self._local_step
-
-        # Start sending fragments
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_prepare_fragment(step):
-                continue
-
-            logger.debug(f"preparing fragment {i} at step {step}")
-
-            self._first_prepare_sent.add(i)
-            fragment.prepare_sync()
-
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_sync_fragment(step):
-                continue
-
-            # We need to have sent an allreduce before we can syncing
-            # a fragment
-            if i not in self._first_prepare_sent:
-                continue
-
-            logger.debug(f"syncing fragment {i} at step {step}")
-
-            if not fragment.perform_sync():
-                # Cancel all the previously scheduled allreduce by simply
-                # waiting for them. They should have failed but lets be
-                # paranoid anyway.
-                #
-                # We could choose to resend the failed fragments but that is
-                # more complicated since it involves coordinating all nodes to
-                # rewind and resend the fragments.
-                self._wait()
-
-                # Reset the local step. This is needed in case manager `should_commit` fails.
-                #
-                # This is because there can be a node that has the same `max_step` as the
-                # nodes that reached the commit point. However, this node failed before
-                # it could reach the commit point. So the local steps for these two nodes
-                # are not the same. But either can be used for recovery.
-                #
-                # To make sure both return the same step, we just reset the step to 0
-                # and start from scratch.
-                #
-                # In the happy path, we don't need to reset the step because --
-                # Nodes participating in the commit bumped their `max_step`.
-                # Any new nodes will take `local_step` from one of these nodes, which must
-                # be the same across all nodes because they took the same number of steps
-                # since the last commit to get to the most recent commit.
-                with self._lock:
-                    self._local_step = 0
-
-            # Avoid doing allreduce after quorum failed.
-            #
-            # Maybe a different quorum formed without this node, so this node
-            # will incorrectly try to allreduce potentially on an incorrect
-            # fragment because the local_step is also out of sync.
-            # The replica will need recovery later anyway.
-            #
-            # So in case it didn't crash (e.g. network errors), we can save some
-            # training data by looping here. Otherwise that training data goes to
-            # waste after recovery
-            self._quorum_loop()
-
-            # TODO: Since we do quorum after commit, there might be a big gap until
-            # the next allreduce. This increases the chances of nodes failing
-            # and so the allreduce to fail.
-            # - We could maybe do a quorum again right before preparing for a fragment
-            #   using `shrink_only`. This might make it tricky for new nodes to join
-            #   though.
-            # - Maintain a sequence number in the state dict that gets bumped at every
-            #   quorum call. Then we can do a quorum right before allreduce and avoid
-            #   doing quorums after commit.
-
-            # We need to set make sure `_local_step` is still
-            # the same across all replicas if `quorum_id` changed.
-            #
-            # We can't garuntee a majority of replicas in this new quorum
-            # has the latest `max_step`.
-            #
-            # TODO: This is garuntee is currently lacking
-            # in torchft unless `shrink_only` is set.
+        self._local_step += 1
+
+        if self._local_step == self._sync_every - self._fragment_sync_delay:
+            # Time to prepare a fragment
             #
-            # After the quorum though, everyone will have the same
-            # `local_step` because replicas with the chosen
-            # `max_step` will have the same `local_step`. That is
-            # because we don't take additional steps after commit.
+            # Some replicas will get the same copy of the model, implying batches
+            # can be overrepresented.
+            self._manager.start_quorum()
+            fragment = self._current_fragment()
+            logger.debug(f"preparing fragment {fragment} of {len(self._fragments)}")
+            self._fragments[fragment].prepare_sync()
+
+        if self._local_step < self._sync_every:
+            return
+
+        if self._local_step == self._sync_every:
+            # Time to sync a fragment
+            fragment = self._current_fragment()
+            self._fragments[fragment].perform_sync()
+
+            # If the allreduce truly failed, we'll keep retrying this fragment.
+            # We reset the parameters upon failure. We'll skip over some data
+            # but we won't over train before syncing.
+
+            self._local_step = 0
+            return
+
+        assert (
+            False
+        ), f"{self._local_step=} should never be greater than {self._sync_every=}"