diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 228b070..5c96fa6 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -281,8 +281,8 @@ def _save_grads(self) -> None:
                     local_param = p.to_local()
                 else:
                     local_param = p
-                pseudogradient = local_param - self.original_parameters[name].to(
-                    p.device
+                pseudogradient = (
+                    self.original_parameters[name].to(p.device) - local_param
                 )
                 self._grads[name] = pseudogradient
 
@@ -318,7 +318,7 @@ def _merge_parameters(self) -> None:
         Merges the local and global parameters.
         """
         for name, p in self._model_fragment.named_parameters():
-            p.data.lerp(self._local_parameters[name], 1 - self._fragment_update_alpha)
+            p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)
 
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
@@ -335,20 +335,6 @@ def wait(self) -> None:
 
         self._allreduce_futures = []
 
-    def should_prepare_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be asynchronously sent to other replicas
-        """
-        step_to_prepare = step - self._fragment_sync_offset
-        return step_to_prepare % self._sync_every == 0
-
-    def should_sync_fragment(self, step: int) -> bool:
-        """
-        Determines if the fragment should be synchronized with other replicas
-        """
-        step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
-        return step_to_sync % self._sync_every == 0
-
     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
@@ -384,27 +370,6 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        # Please make sure to not do this to avoid running into inconsistencies.
-        #
-        # This can happen when using large values of `fragment_sync_delay`.
-        # The node might not have participated in syncing of this fragment.
-        #
-        # The allreduce for other nodes who did might actually
-        # succeed and in that case, we shouldn't allow recovery
-        # from this node.
-        #
-        # We do need to increase the `max_step` here so we
-        # don't end up in an infinite loop of needing to recover
-        # but we can't let other nodes recover from this node
-        # because it doesn't have the latest state.
-        #
-        # We can add a `is_catching_up` flag to the state_dict
-        # to disallow recoveries from this node. Such nodes can
-        # be excluded from `max_step` calculation unless all
-        # nodes are catching up. This approach makes the replica state
-        # of global parameters diverge though. So we could add recovery
-        # for a particular fragment from a peer node as a part of the
-        # `should_commit` or next `quorum` when a node is catching up.
         assert len(self._allreduce_futures) > 0
 
         self.wait()
@@ -588,7 +553,11 @@ def __init__(
         if sync_every < len(model_fragments):
             raise ValueError("Only 1 fragment can be syncrhonized at a time")
 
-        if fragment_sync_delay >= sync_every:
+        if sync_every % len(model_fragments) != 0:
+            raise ValueError("sync_every must be divisible by the number of fragments")
+
+        self._sync_every: int = sync_every // len(model_fragments)
+        if fragment_sync_delay >= self._sync_every:
             raise ValueError(
                 "Fragment must be synced before it is reduced another time"
             )
@@ -599,23 +568,12 @@ def __init__(
         super().__init__()
         self._manager = manager
 
-        # Protects `_local_step`
-        self._lock = threading.Lock()
-
         # The number of training iterations performed.
         # Used to synchronize which fragment to send across all
         # replicas
        self._local_step = 0
 
-        # Sync `_local_step` with other replicas
-        self._manager.register_state_dict_fn(
-            "local_step",
-            self._load_step,
-            lambda: self._local_step,
-        )
-
-        # Used to perform quorum before any training happens
-        self._should_recover = True
+        self._fragment_sync_delay = fragment_sync_delay
 
         self._hooks: List[RemovableHandle] = []
 
@@ -648,16 +606,9 @@ def __init__(
         # `_StreamingDiLoCoFragment` about the fragment sync schedule.
         assert fragment_sync_delay < sync_every // len(model_fragments)
 
-        # Used to ensure that we try to sync a fragment after we've sent a prepare for it
-        self._first_prepare_sent: set[int] = set()
-
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
-    def _load_step(self, step: int) -> None:
-        with self._lock:
-            self._local_step = step
-
     def _save_parameters(self) -> None:
         for fragment in self._fragments:
             fragment.save_parameters()
@@ -694,17 +645,12 @@ def _wait(self) -> None:
         for fragment in self._fragments:
             fragment.wait()
 
-        self._first_prepare_sent.clear()
-
-    def _quorum_loop(self) -> None:
+    def _current_fragment(self) -> int:
         """
-        Performs infinite retries until quorum is successfull
+        Determines which fragment to prepare/sync based on the current step.
         """
-        while True:
-            self._manager.start_quorum()
-
-            if self._manager.errored() is None:
-                return
+        step = self._manager.current_step()
+        return step % len(self._fragments)
 
     def _step_post_hook(
         self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
@@ -712,14 +658,6 @@ def _step_post_hook(
         """
         This hook is registered on the optimizer and is called after the optimizer step.
         """
-        if self._should_recover:
-            # Get the correct step when. This will continue after other committed.
-            self._quorum_loop()
-            self._should_recover = False
-            # This is to be consistent with the nodes that are not recovering. They
-            # proceed with the code below on the step after quorum completes.
-            return
-
         # We need to make sure all nodes send the same fragments in order.
         # This is to avoid deadlocking e.g.
         #
@@ -730,91 +668,32 @@ def _step_post_hook(
         #
         # Both of them will fail because Node A didn't send fragment 2
         # and Node B didn't send fragment 1.
-        with self._lock:
-            self._local_step += 1
-            step = self._local_step
-
-        # Start sending fragments
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_prepare_fragment(step):
-                continue
-
-            logger.debug(f"preparing fragment {i} at step {step}")
-
-            self._first_prepare_sent.add(i)
-            fragment.prepare_sync()
-
-        for i, fragment in enumerate(self._fragments):
-            if not fragment.should_sync_fragment(step):
-                continue
-
-            # We need to have sent an allreduce before we can syncing
-            # a fragment
-            if i not in self._first_prepare_sent:
-                continue
-
-            logger.debug(f"syncing fragment {i} at step {step}")
-
-            if not fragment.perform_sync():
-                # Cancel all the previously scheduled allreduce by simply
-                # waiting for them. They should have failed but lets be
-                # paranoid anyway.
-                #
-                # We could choose to resend the failed fragments but that is
-                # more complicated since it involves coordinating all nodes to
-                # rewind and resend the fragments.
-                self._wait()
-
-                # Reset the local step. This is needed in case manager `should_commit` fails.
-                #
-                # This is because there can be a node that has the same `max_step` as the
-                # nodes that reached the commit point. However, this node failed before
-                # it could reach the commit point. So the local steps for these two nodes
-                # are not the same. But either can be used for recovery.
-                #
-                # To make sure both return the same step, we just reset the step to 0
-                # and start from scratch.
-                #
-                # In the happy path, we don't need to reset the step because --
-                # Nodes participating in the commit bumped their `max_step`.
-                # Any new nodes will take `local_step` from one of these nodes, which must
-                # be the same across all nodes because they took the same number of steps
-                # since the last commit to get to the most recent commit.
-                with self._lock:
-                    self._local_step = 0
-
-        # Avoid doing allreduce after quorum failed.
-        #
-        # Maybe a different quorum formed without this node, so this node
-        # will incorrectly try to allreduce potentially on an incorrect
-        # fragment because the local_step is also out of sync.
-        # The replica will need recovery later anyway.
-        #
-        # So in case it didn't crash (e.g. network errors), we can save some
-        # training data by looping here. Otherwise that training data goes to
-        # waste after recovery
-        self._quorum_loop()
-
-        # TODO: Since we do quorum after commit, there might be a big gap until
-        # the next allreduce. This increases the chances of nodes failing
-        # and so the allreduce to fail.
-        # - We could maybe do a quorum again right before preparing for a fragment
-        #   using `shrink_only`. This might make it tricky for new nodes to join
-        #   though.
-        # - Maintain a sequence number in the state dict that gets bumped at every
-        #   quorum call. Then we can do a quorum right before allreduce and avoid
-        #   doing quorums after commit.
-
-        # We need to set make sure `_local_step` is still
-        # the same across all replicas if `quorum_id` changed.
-        #
-        # We can't garuntee a majority of replicas in this new quorum
-        # has the latest `max_step`.
-        #
-        # TODO: This is garuntee is currently lacking
-        # in torchft unless `shrink_only` is set.
+        self._local_step += 1
+
+        if self._local_step == self._sync_every - self._fragment_sync_delay:
+            # Time to prepare a fragment
         #
-        # After the quorum though, everyone will have the same
-        # `local_step` because replicas with the chosen
-        # `max_step` will have the same `local_step`. That is
-        # because we don't take additional steps after commit.
+            # Some replicas will get the same copy of the model, implying batches
+            # can be overrepresented.
+            self._manager.start_quorum()
+            fragment = self._current_fragment()
+            self._fragments[fragment].prepare_sync()
+
+        if self._local_step < self._sync_every:
+            return
+
+        if self._local_step == self._sync_every:
+            # Time to sync a fragment
+            fragment = self._current_fragment()
+            self._fragments[fragment].perform_sync()
+
+            # If the allreduce truly failed, we'll keep retrying this fragment.
+            # We reset the parameters upon failure. We'll skip over some data
+            # but we won't overtrain before syncing.
+
+            self._local_step = 0
+            return
+
+        assert (
+            False
+        ), f"{self._local_step=} should never be greater than {self._sync_every=}"
diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 2b0e065..70560f2 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -33,38 +33,11 @@
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
 )
+from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class MultiMyModel(torch.nn.Module):
-    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-
-        self.layers = torch.nn.ModuleList()
-        for i in range(n_layers):
-            self.layers.append(MyModel(in_dim, out_dim))
-            in_dim, out_dim = out_dim, in_dim
-
-        self.out_dim = in_dim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-    def get_rand_inputs(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim, device=device)
-
-    def get_rand_labels(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.randint(self.out_dim, (batch_size,), device=device)
-
-
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
@@ -148,158 +121,11 @@ def diloco_train_loop(
     diloco_args = train_loop_args.get("diloco_args", {})
 
     with ExitStack() as stack:
-        # Declare the model and optimizers
-        m = MultiMyModel(2, 3, n_fragments)
-        m.load_state_dict(model_state_dict)
-        m.to(device)
-
-        # Setup optimizers
-        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
-            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        trainer = DiLoCoTrainer(
+            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
         )
-
-        # Create one outer optimizer per fragment
-        outer_optimizers = []
-        for _, layer in enumerate(m.layers):
-            outer_optimizers.append(
-                torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
-            )
-
-        # pyre-ignore[53]
-        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
-            m.load_state_dict(state_dict["model"])
-            m.to(device)
-
-            # Load original parameters for each fragment
-            for i, fragment in enumerate(diloco._fragments):
-                fragment.original_parameters = cast(
-                    Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
-                )
-
-            for fragment in diloco._fragments:
-                for name in fragment.original_parameters.keys():
-                    fragment.original_parameters[name] = fragment.original_parameters[
-                        name
-                    ].to(device)
-
-            inner_optimizer.load_state_dict(state_dict["inner_optim"])
-            for i, optimizer in enumerate(outer_optimizers):
-                optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])
-
-        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
-            return {
-                "model": m.state_dict(),
-                "original_params": {
-                    f"{i}": fragment.original_parameters
-                    for i, fragment in enumerate(diloco._fragments)
-                },
-                "inner_optim": inner_optimizer.state_dict(),
-                "outer_optim": {
-                    f"{i}": optimizer.state_dict()
-                    for i, optimizer in enumerate(outer_optimizers)
-                },
-            }
-
-        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
-
-        if device.type == "cuda":
-            pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
-        else:
-            pg = FakeProcessGroupWrapper(
-                ProcessGroupGloo(timeout=timedelta(seconds=10))
-            )
-        manager = Manager(
-            pg=pg,
-            min_replica_size=2,
-            use_async_quorum=False,
-            load_state_dict=load_state_dict,
-            state_dict=state_dict,
-            replica_id=str(runner.replica_id),
-            store_addr="localhost",
-            store_port=store_port,
-            rank=rank,
-            world_size=runner.world_size,
-            lighthouse_addr=runner.lighthouse_address,
-            port=19530 + runner.replica_id,
-            connect_timeout=timedelta(seconds=10),
-            quorum_timeout=timedelta(seconds=10),
-            timeout=timedelta(seconds=10),
-            # pyre-fixme[6]: Incompatible parameter type
-            **runner.manager_args,
-        )
-        runner.event_injector.set_pg(pg)
-        stack.callback(manager.shutdown)
-        # initialize default group for device mesh to work
-        if not torch.distributed.is_initialized():
-            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
-            try:
-                torch.distributed.init_process_group(
-                    init_method="tcp://localhost:0",
-                    rank=rank,
-                    world_size=runner.world_size,
-                )
-            except ValueError:
-                os.environ["MASTER_ADDR"] = "localhost"
-                os.environ["MASTER_PORT"] = "0"
-                os.environ["WORLD_SIZE"] = str(runner.world_size)
-                os.environ["RANK"] = str(rank)
-
-        device_type = device.type
-        ft_device_mesh = ft_init_device_mesh(
-            device_type=device_type,
-            mesh_shape=(runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=manager,
-        )
-        for layer in m.layers:
-            if isinstance(layer, nn.Linear):
-                for param in layer.parameters():
-                    param = DTensor.from_local(
-                        param,
-                        device_mesh=ft_device_mesh,
-                    )
-
-        criterion = nn.CrossEntropyLoss()
-        all_state_dicts = {}
-
-        if "sync_every" not in diloco_args:
-            diloco_args["sync_every"] = 2
-
-        with DiLoCo(
-            manager,
-            [layer for layer in m.layers],
-            inner_optimizer,
-            outer_optimizers,
-            backup_device=device,
-            **diloco_args,
-        ) as diloco:
-            while True:
-                runner.event_injector.check(rank, manager.current_step())
-
-                manager_curr_step = manager.current_step()
-                if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(
-                        manager._manager_state_dict()
-                    )
-
-                batch_size = 1
-                inputs = m.get_rand_inputs(batch_size, device=device)
-                labels = m.get_rand_labels(batch_size, device=device)
-
-                out = m(inputs)
-                loss = criterion(out, labels)
-
-                inner_optimizer.zero_grad()
-                loss.backward()
-                inner_optimizer.step()
-
-                # after 4 model updates then break
-                if manager.current_step() >= 4:
-                    break
-
-        # return state_dict so we can check consistency
-        return all_state_dicts
+        stack.callback(trainer.manager.shutdown)
+        return trainer.train_loop()
 
     return {}
@@ -592,16 +418,6 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
 
         assert_equal_global_state(rep1, rep0)
 
-        for step in rep1.keys():
-            if step == 2:
-                # Replica 0 should have reset its `local_step` after failure
-                self.assertEqual(rep1[step]["user"]["local_step"], 0)
-                self.assertEqual(rep0[step]["user"]["local_step"], 5)
-            else:
-                self.assertEqual(
-                    rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-                )
-
         self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
 
 CONFIG: list[tuple[bool, int, int, float]] = [
@@ -683,14 +499,6 @@ def test_streaming_diloco_upscale(
 
         assert_equal_global_state(rep0, rep1)
         assert_equal_global_state(rep0, rep2)
 
-        for step in rep0.keys():
-            self.assertEqual(
-                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-            )
-            self.assertEqual(
-                rep1[step]["user"]["local_step"], rep2[step]["user"]["local_step"]
-            )
-
         for event_injector in event_injectors:
             self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)
@@ -760,11 +568,6 @@ def test_streaming_diloco_commit_failure(
 
         assert_equal_global_state(rep0, rep1)
 
-        for step in rep0.keys():
-            self.assertEqual(
-                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
-            )
-
         for event_injector in event_injectors:
             self.assertEqual(
                 event_injector.count[EventInjectorEvent.AllreduceFailure], 1
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index d1dd183..04aede4 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -157,20 +157,15 @@ def test_diloco_healthy(self) -> None:
             loss.backward()
             inner_optimizer.step()
 
-            self.assertEqual(diloco._local_step, 0)
-            loss = model(inp).mean()
-            loss.backward()
-            inner_optimizer.step()
-            self.assertEqual(diloco._local_step, 1)
-            self.assertEqual(manager.start_quorum.call_count, 1)
+            manager.current_step.return_value = 0
+            manager.should_commit.return_value = True
 
             loss = model(inp).mean()
             loss.backward()
             inner_optimizer.step()
 
-            self.assertEqual(manager.start_quorum.call_count, 2)
-            manager.should_commit.return_value = True
-            self.assertEqual(diloco._local_step, 2)
+            self.assertEqual(diloco._local_step, 0)
+            self.assertEqual(manager.start_quorum.call_count, 1)
             torch.testing.assert_close(
                 diloco._fragments[0].original_parameters, _params_dict(model)
             )
@@ -218,6 +213,7 @@ def test_diloco_allreduce_call_efficiency(
             loss.backward()
             inner_optimizer.step()
 
+            manager.current_step.return_value = 0
             loss = model(inp).mean()
             loss.backward()
             inner_optimizer.step()
@@ -320,8 +316,8 @@ def fake_allreduce(
         diloco._fragments[0]._set_grads()
 
         # we added 2 to the parameters, then multiplied the gradients by 2
-        # so we should expect the model's gradient to be 4
-        expected_grad = 4
+        # so we should expect the model's gradient to be -4
+        expected_grad = -4
         for param in model.parameters():
             assert param.grad is not None
             t = torch.empty_like(param.grad)
diff --git a/torchft/test/diloco_trainer.py b/torchft/test/diloco_trainer.py
new file mode 100644
index 0000000..f658417
--- /dev/null
+++ b/torchft/test/diloco_trainer.py
@@ -0,0 +1,310 @@
+import copy
+import logging
+import os
+from contextlib import ExitStack
+from datetime import timedelta
+from typing import Any, Dict, List, cast
+
+import torch
+from torch import nn
+from torch.distributed.tensor import DTensor
+
+from torchft.device_mesh import ManagedDeviceMesh, ft_init_device_mesh
+from torchft.local_sgd import DiLoCo
+from torchft.manager import Manager
+from torchft.manager_integ_test import MyModel, Runner
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class MultiModel(torch.nn.Module):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class MultiMyModel(MultiModel):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+
+        for _ in range(n_layers):
+            self.layers.append(MyModel(in_dim, out_dim))
+            in_dim, out_dim = out_dim, in_dim
+
+        self.out_dim = in_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(self.out_dim, (batch_size,), device=device)
+
+
+class DiLoCoTrainer:
+    """
+    A class that encapsulates the DiLoCo training process.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        store_port: int,
+        device: torch.device,
+        runner: Runner,
+        model_state_dict: dict[str, Any],
+        n_fragments: int,
+        diloco_args: dict[str, Any],
+    ) -> None:
+        """
+        Initialize the DiLoCoTrainer.
+
+        Args:
+            rank: The rank of the current process.
+            store_port: The port for the store.
+            device: The device to use for training.
+            runner: The runner instance.
+            model_state_dict: The state dict to load into the model.
+            n_fragments: The number of model fragments.
+            diloco_args: Additional arguments passed through to DiLoCo.
+        """
+        self.rank: int = rank
+        self.store_port: int = store_port
+        self.device: torch.device = device
+        self.runner: Runner = runner
+
+        # Training configuration
+        self.model_state_dict: Dict[str, Any] = model_state_dict
+        self.n_fragments: int = n_fragments
+        self.diloco_args: dict[str, Any] = diloco_args
+
+        # Initialize components
+        self.model: MultiModel = self.setup_model()
+        self.inner_optimizer: torch.optim.Optimizer = self.setup_inner_optimizer()
+        self.outer_optimizers: list[torch.optim.Optimizer] = (
+            self.setup_outer_optimizers()
+        )
+
+        self.pg: FakeProcessGroupWrapper = self.setup_pg()
+        # Set up the process group for the event injector
+        self.runner.event_injector.set_pg(self.pg)
+
+        self.manager: Manager = self.setup_manager()
+
+        self.ft_device_mesh: None | ManagedDeviceMesh = None
+        self.setup_distributed()
+
+        self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
+
+        self.diloco: DiLoCo | None = None
+
+    def setup_model(self) -> MultiModel:
+        """Set up the model and move it to the device."""
+        model = MultiMyModel(2, 3, self.n_fragments)
+        model.load_state_dict(self.model_state_dict)
+        model.to(self.device)
+        return model
+
+    def setup_inner_optimizer(self) -> torch.optim.Optimizer:
+        """Set up the inner optimizer."""
+        return torch.optim.AdamW(
+            self.model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+
+    def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
+        """Set up outer optimizers."""
+        # Create one outer optimizer per fragment
+        outer_optimizers = []
+        for _, layers in enumerate(self.model.layers):
+            outer_optimizers.append(
+                torch.optim.SGD(
+                    layers.parameters(), lr=0.7, momentum=0.9, nesterov=True
+                )
+            )
+        return outer_optimizers
+
+    def setup_pg(self) -> FakeProcessGroupWrapper:
+        if self.device.type == "cuda":
+            return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
+        else:
+            return FakeProcessGroupWrapper(
+                ProcessGroupGloo(timeout=timedelta(seconds=10))
+            )
+
+    def setup_manager(self) -> Manager:
+        """Set up the manager."""
+        print(
+            f"worker {self.runner.replica_id=} {self.rank=} {self.runner.world_size=} starting"
+        )
+
+        # Create manager with all arguments passed directly
+        return Manager(
+            pg=self.pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=self.load_state_dict,
+            state_dict=self.state_dict,
+            replica_id=str(self.runner.replica_id),
+            store_addr="localhost",
+            store_port=self.store_port,
+            rank=self.rank,
+            world_size=self.runner.world_size,
+            lighthouse_addr=self.runner.lighthouse_address,
+            port=19530 + self.runner.replica_id,
+            connect_timeout=timedelta(seconds=10),
+            quorum_timeout=timedelta(seconds=10),
+            timeout=timedelta(seconds=10),
+            **self.runner.manager_args,  # type: ignore
+        )
+
+    def setup_distributed(self) -> None:
+        """Set up distributed training."""
+        # Initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
+            try:
+                torch.distributed.init_process_group(
+                    init_method="tcp://localhost:0",
+                    rank=self.rank,
+                    world_size=self.runner.world_size,
+                )
+            except ValueError:
+                os.environ["MASTER_ADDR"] = "localhost"
+                os.environ["MASTER_PORT"] = "0"
+                os.environ["WORLD_SIZE"] = str(self.runner.world_size)
+                os.environ["RANK"] = str(self.rank)
+
+        self.ft_device_mesh = ft_init_device_mesh(
+            device_type=self.device.type,
+            mesh_shape=(self.runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=self.manager,
+        )
+
+        # Convert model parameters to DTensor
+        for layer in self.model.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=self.ft_device_mesh,
+                    )
+
+    def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
+        """
+        Load the state dictionary.
+
+        Args:
+            state_dict: The state dictionary to load.
+        """
+        assert self.diloco is not None
+
+        self.model.load_state_dict(state_dict["model"])
+        self.model.to(self.device)
+
+        # Load original parameters for each fragment
+        for i, fragment in enumerate(cast(DiLoCo, self.diloco)._fragments):
+            fragment.original_parameters = cast(
+                Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
+            )
+
+        for fragment in cast(DiLoCo, self.diloco)._fragments:
+            for name in fragment.original_parameters.keys():
+                fragment.original_parameters[name] = fragment.original_parameters[
+                    name
+                ].to(self.device)
+
+        self.inner_optimizer.load_state_dict(state_dict["inner_optim"])
+        for i, optimizer in enumerate(self.outer_optimizers):
+            optimizer.load_state_dict(
+                cast(dict[str, torch.Tensor], state_dict["outer_optim"][f"{i}"])
+            )
+
+    def state_dict(self) -> Dict[str, Dict[str, object]]:
+        """
+        Get the state dictionary.
+
+        Returns:
+            The state dictionary.
+        """
+        assert self.diloco is not None
+
+        return {
+            "model": self.model.state_dict(),
+            "original_params": {
+                f"{i}": fragment.original_parameters
+                for i, fragment in enumerate(cast(DiLoCo, self.diloco)._fragments)
+            },
+            "inner_optim": self.inner_optimizer.state_dict(),
+            "outer_optim": {
+                f"{i}": optimizer.state_dict()
+                for i, optimizer in enumerate(self.outer_optimizers)
+            },
+        }
+
+    def train_loop(self) -> dict[str, Any]:
+        """Run the training loop."""
+        all_state_dicts = {}
+
+        # Ensure sync_every is set in diloco_args
+        if "sync_every" not in self.diloco_args:
+            self.diloco_args["sync_every"] = 2
+
+        with DiLoCo(
+            self.manager,
+            [layer for layer in self.model.layers],
+            self.inner_optimizer,
+            self.outer_optimizers,
+            backup_device=self.device,
+            **self.diloco_args,
+        ) as self.diloco:
+            while True:
+                self.runner.event_injector.check(self.rank, self.manager.current_step())
+
+                manager_curr_step = self.manager.current_step()
+                if manager_curr_step not in all_state_dicts:
+                    # Store a copy of the manager state dict
+                    all_state_dicts[manager_curr_step] = copy.deepcopy(
+                        self.manager._manager_state_dict()
+                    )
+
+                batch_size = 1
+                inputs = self.model.get_rand_inputs(batch_size, device=self.device)
+                labels = self.model.get_rand_labels(batch_size, device=self.device)
+
+                out = self.model(inputs)
+                loss = self.criterion(out, labels)
+
+                self.inner_optimizer.zero_grad()
+                loss.backward()
+                self.inner_optimizer.step()
+
+                # Break after 4 model updates
+                if self.manager.current_step() >= 4:
+                    break
+
+        return all_state_dicts
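Note on the new sync schedule in `_step_post_hook`: the per-step logic can be summarized with a small standalone sketch. This is illustrative only and not part of the patch; `schedule`, `total_steps`, and `global_step` are made-up names, `sync_every` here means the per-fragment value (DiLoCo's `sync_every // len(model_fragments)`), and the sketch assumes the manager step advances by one at every successful sync, which is what keeps the fragment index consistent across replicas without exchanging `local_step`.

# Minimal sketch of the fragment schedule (illustrative, not part of the patch).
def schedule(
    total_steps: int, n_fragments: int, sync_every: int, fragment_sync_delay: int
) -> None:
    local_step = 0  # purely local counter, reset after every sync
    global_step = 0  # stand-in for Manager.current_step()

    for _ in range(total_steps):
        local_step += 1

        if local_step == sync_every - fragment_sync_delay:
            # Start the allreduce (prepare_sync) for the fragment derived from the manager step.
            print(f"manager step {global_step}: prepare fragment {global_step % n_fragments}")

        if local_step == sync_every:
            # Wait for the allreduce and apply the outer optimizer (perform_sync) to the same fragment.
            print(f"manager step {global_step}: sync fragment {global_step % n_fragments}")
            local_step = 0
            global_step += 1  # assumes should_commit() succeeds and bumps the manager step


schedule(total_steps=12, n_fragments=2, sync_every=3, fragment_sync_delay=1)

With these numbers each fragment's allreduce is started one inner step before it is applied, and the synced fragment alternates 0, 1, 0, 1 across sync points, matching `_current_fragment()`, which returns `current_step() % len(self._fragments)`.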