diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 2b0e065..a6c39cd 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -33,38 +33,11 @@
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
 )
+from torchft.test.diloco_trainer import DiLoCoTrainer, MultiMyModel
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class MultiMyModel(torch.nn.Module):
-    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
-        super().__init__()
-        self.in_dim = in_dim
-
-        self.layers = torch.nn.ModuleList()
-        for i in range(n_layers):
-            self.layers.append(MyModel(in_dim, out_dim))
-            in_dim, out_dim = out_dim, in_dim
-
-        self.out_dim = in_dim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-    def get_rand_inputs(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.rand(batch_size, self.in_dim, device=device)
-
-    def get_rand_labels(
-        self, batch_size: int, device: torch.device = torch.device("cpu")
-    ) -> torch.Tensor:
-        return torch.randint(self.out_dim, (batch_size,), device=device)
-
-
 def local_sgd_train_loop(
     rank: int,
     store_port: int,
@@ -148,158 +121,11 @@ def diloco_train_loop(
     diloco_args = train_loop_args.get("diloco_args", {})
 
     with ExitStack() as stack:
-        # Declare the model and optimizers
-        m = MultiMyModel(2, 3, n_fragments)
-        m.load_state_dict(model_state_dict)
-        m.to(device)
-
-        # Setup optimizers
-        inner_optimizer: optim.Optimizer = torch.optim.AdamW(
-            m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        trainer = DiLoCoTrainer(
+            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
         )
-
-        # Create one outer optimizer per fragment
-        outer_optimizers = []
-        for _, layer in enumerate(m.layers):
-            outer_optimizers.append(
-                torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
-            )
-
-        # pyre-ignore[53]
-        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
-            m.load_state_dict(state_dict["model"])
-            m.to(device)
-
-            # Load original parameters for each fragment
-            for i, fragment in enumerate(diloco._fragments):
-                fragment.original_parameters = cast(
-                    Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
-                )
-
-            for fragment in diloco._fragments:
-                for name in fragment.original_parameters.keys():
-                    fragment.original_parameters[name] = fragment.original_parameters[
-                        name
-                    ].to(device)
-
-            inner_optimizer.load_state_dict(state_dict["inner_optim"])
-            for i, optimizer in enumerate(outer_optimizers):
-                optimizer.load_state_dict(state_dict[f"outer_optim"][f"{i}"])
-
-        def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
-            return {
-                "model": m.state_dict(),
-                "original_params": {
-                    f"{i}": fragment.original_parameters
-                    for i, fragment in enumerate(diloco._fragments)
-                },
-                "inner_optim": inner_optimizer.state_dict(),
-                "outer_optim": {
-                    f"{i}": optimizer.state_dict()
-                    for i, optimizer in enumerate(outer_optimizers)
-                },
-            }
-
-        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
-
-        if device.type == "cuda":
-            pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
-        else:
-            pg = FakeProcessGroupWrapper(
-                ProcessGroupGloo(timeout=timedelta(seconds=10))
-            )
-        manager = Manager(
-            pg=pg,
-            min_replica_size=2,
-            use_async_quorum=False,
-            load_state_dict=load_state_dict,
-            state_dict=state_dict,
-            replica_id=str(runner.replica_id),
-            store_addr="localhost",
-            store_port=store_port,
-            rank=rank,
-            world_size=runner.world_size,
-            lighthouse_addr=runner.lighthouse_address,
-            port=19530 + runner.replica_id,
-            connect_timeout=timedelta(seconds=10),
-            quorum_timeout=timedelta(seconds=10),
-            timeout=timedelta(seconds=10),
-            # pyre-fixme[6]: Incompatible parameter type
-            **runner.manager_args,
-        )
-        runner.event_injector.set_pg(pg)
-        stack.callback(manager.shutdown)
-        # initialize default group for device mesh to work
-        if not torch.distributed.is_initialized():
-            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
-            try:
-                torch.distributed.init_process_group(
-                    init_method="tcp://localhost:0",
-                    rank=rank,
-                    world_size=runner.world_size,
-                )
-            except ValueError:
-                os.environ["MASTER_ADDR"] = "localhost"
-                os.environ["MASTER_PORT"] = "0"
-                os.environ["WORLD_SIZE"] = str(runner.world_size)
-                os.environ["RANK"] = str(rank)
-
-        device_type = device.type
-        ft_device_mesh = ft_init_device_mesh(
-            device_type=device_type,
-            mesh_shape=(runner.world_size, 1),
-            mesh_dim_names=("replicate", "none"),
-            replicate_dim=0,
-            manager=manager,
-        )
-        for layer in m.layers:
-            if isinstance(layer, nn.Linear):
-                for param in layer.parameters():
-                    param = DTensor.from_local(
-                        param,
-                        device_mesh=ft_device_mesh,
-                    )
-
-        criterion = nn.CrossEntropyLoss()
-        all_state_dicts = {}
-
-        if "sync_every" not in diloco_args:
-            diloco_args["sync_every"] = 2
-
-        with DiLoCo(
-            manager,
-            [layer for layer in m.layers],
-            inner_optimizer,
-            outer_optimizers,
-            backup_device=device,
-            **diloco_args,
-        ) as diloco:
-            while True:
-                runner.event_injector.check(rank, manager.current_step())
-
-                manager_curr_step = manager.current_step()
-                if manager_curr_step not in all_state_dicts:
-                    all_state_dicts[manager_curr_step] = copy.deepcopy(
-                        manager._manager_state_dict()
-                    )
-
-                batch_size = 1
-                inputs = m.get_rand_inputs(batch_size, device=device)
-                labels = m.get_rand_labels(batch_size, device=device)
-
-                out = m(inputs)
-                loss = criterion(out, labels)
-
-                inner_optimizer.zero_grad()
-                loss.backward()
-                inner_optimizer.step()
-
-                # after 4 model updates then break
-                if manager.current_step() >= 4:
-                    break
-
-        # return state_dict so we can check consistency
-        return all_state_dicts
+        stack.callback(trainer.manager.shutdown)
+        return trainer.train_loop()
 
     return {}
 
 
diff --git a/torchft/test/diloco_trainer.py b/torchft/test/diloco_trainer.py
new file mode 100644
index 0000000..f658417
--- /dev/null
+++ b/torchft/test/diloco_trainer.py
@@ -0,0 +1,310 @@
+import copy
+import logging
+import os
+from contextlib import ExitStack
+from datetime import timedelta
+from typing import Any, Dict, List, cast
+
+import torch
+from torch import nn
+from torch.distributed.tensor import DTensor
+
+from torchft.device_mesh import ManagedDeviceMesh, ft_init_device_mesh
+from torchft.local_sgd import DiLoCo
+from torchft.manager import Manager
+from torchft.manager_integ_test import MyModel, Runner
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class MultiModel(torch.nn.Module):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
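+# MultiMyModel stacks MyModel blocks, swapping in/out dims between layers so the
+# layers compose; each layer is treated as one DiLoCo fragment by DiLoCoTrainer.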
+class MultiMyModel(MultiModel):
+    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+        super().__init__()
+        self.in_dim = in_dim
+
+        for _ in range(n_layers):
+            self.layers.append(MyModel(in_dim, out_dim))
+            in_dim, out_dim = out_dim, in_dim
+
+        self.out_dim = in_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+    def get_rand_inputs(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.rand(batch_size, self.in_dim, device=device)
+
+    def get_rand_labels(
+        self, batch_size: int, device: torch.device = torch.device("cpu")
+    ) -> torch.Tensor:
+        return torch.randint(self.out_dim, (batch_size,), device=device)
+
+
+class DiLoCoTrainer:
+    """
+    A class that encapsulates the DiLoCo training process.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        store_port: int,
+        device: torch.device,
+        runner: Runner,
+        model_state_dict: dict[str, Any],
+        n_fragments: int,
+        diloco_args: dict[str, Any],
+    ) -> None:
+        """
+        Initialize the DiLoCoTrainer.
+
+        Args:
+            rank: The rank of the current process.
+            store_port: The port for the store.
+            device: The device to use for training.
+            runner: The runner instance.
+            model_state_dict: The state dict to initialize the model with.
+            n_fragments: The number of model fragments (one layer per fragment).
+            diloco_args: Additional arguments passed through to DiLoCo.
+        """
+        self.rank: int = rank
+        self.store_port: int = store_port
+        self.device: torch.device = device
+        self.runner: Runner = runner
+
+        # Training configuration
+        self.model_state_dict: Dict[str, Any] = model_state_dict
+        self.n_fragments: int = n_fragments
+        self.diloco_args: dict[str, Any] = diloco_args
+
+        # Initialize components
+        self.model: MultiModel = self.setup_model()
+        self.inner_optimizer: torch.optim.Optimizer = self.setup_inner_optimizer()
+        self.outer_optimizers: list[torch.optim.Optimizer] = (
+            self.setup_outer_optimizers()
+        )
+
+        self.pg: FakeProcessGroupWrapper = self.setup_pg()
+        # Set up the process group for the event injector
+        self.runner.event_injector.set_pg(self.pg)
+
+        self.manager: Manager = self.setup_manager()
+
+        self.ft_device_mesh: None | ManagedDeviceMesh = None
+        self.setup_distributed()
+
+        self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
+
+        self.diloco: DiLoCo | None = None
+
+    def setup_model(self) -> MultiModel:
+        """Set up the model and move it to the device."""
+        model = MultiMyModel(2, 3, self.n_fragments)
+        model.load_state_dict(self.model_state_dict)
+        model.to(self.device)
+        return model
+
+    def setup_inner_optimizer(self) -> torch.optim.Optimizer:
+        """Set up the inner optimizer."""
+        return torch.optim.AdamW(
+            self.model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+
+    def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
+        """Set up outer optimizers."""
+        # Create one outer optimizer per fragment
+        outer_optimizers = []
+        for _, layers in enumerate(self.model.layers):
+            outer_optimizers.append(
+                torch.optim.SGD(
+                    layers.parameters(), lr=0.7, momentum=0.9, nesterov=True
+                )
+            )
+        return outer_optimizers
+
+    def setup_pg(self) -> FakeProcessGroupWrapper:
+        if self.device.type == "cuda":
+            return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
+        else:
+            return FakeProcessGroupWrapper(
+                ProcessGroupGloo(timeout=timedelta(seconds=10))
+            )
+
+    def setup_manager(self) -> Manager:
+        """Set up the process group and manager."""
+        print(
+            f"worker {self.runner.replica_id=} {self.rank=} {self.runner.world_size=} starting"
+        )
+
+        # Create manager with all arguments passed directly
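+        # Each replica gets its own manager port (19530 + replica_id) so
+        # replicas running on the same host do not collide.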
+        return Manager(
+            pg=self.pg,
+            min_replica_size=2,
+            use_async_quorum=False,
+            load_state_dict=self.load_state_dict,
+            state_dict=self.state_dict,
+            replica_id=str(self.runner.replica_id),
+            store_addr="localhost",
+            store_port=self.store_port,
+            rank=self.rank,
+            world_size=self.runner.world_size,
+            lighthouse_addr=self.runner.lighthouse_address,
+            port=19530 + self.runner.replica_id,
+            connect_timeout=timedelta(seconds=10),
+            quorum_timeout=timedelta(seconds=10),
+            timeout=timedelta(seconds=10),
+            **self.runner.manager_args,  # type: ignore
+        )
+
+    def setup_distributed(self) -> None:
+        """Set up distributed training."""
+        # Initialize default group for device mesh to work
+        if not torch.distributed.is_initialized():
+            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
+            try:
+                torch.distributed.init_process_group(
+                    init_method="tcp://localhost:0",
+                    rank=self.rank,
+                    world_size=self.runner.world_size,
+                )
+            except ValueError:
+                os.environ["MASTER_ADDR"] = "localhost"
+                os.environ["MASTER_PORT"] = "0"
+                os.environ["WORLD_SIZE"] = str(self.runner.world_size)
+                os.environ["RANK"] = str(self.rank)
+
+        self.ft_device_mesh = ft_init_device_mesh(
+            device_type=self.device.type,
+            mesh_shape=(self.runner.world_size, 1),
+            mesh_dim_names=("replicate", "none"),
+            replicate_dim=0,
+            manager=self.manager,
+        )
+
+        # Convert model parameters to DTensor
+        for layer in self.model.layers:
+            if isinstance(layer, nn.Linear):
+                for param in layer.parameters():
+                    param = DTensor.from_local(
+                        param,
+                        device_mesh=self.ft_device_mesh,
+                    )
+
+    def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
+        """
+        Load the state dictionary.
+
+        Args:
+            state_dict: The state dictionary to load.
+        """
+        assert self.diloco is not None
+
+        self.model.load_state_dict(state_dict["model"])
+        self.model.to(self.device)
+
+        # Load original parameters for each fragment
+        for i, fragment in enumerate(cast(DiLoCo, self.diloco)._fragments):
+            fragment.original_parameters = cast(
+                Dict[str, torch.Tensor], state_dict["original_params"][f"{i}"]
+            )
+
+        for fragment in cast(DiLoCo, self.diloco)._fragments:
+            for name in fragment.original_parameters.keys():
+                fragment.original_parameters[name] = fragment.original_parameters[
+                    name
+                ].to(self.device)
+
+        self.inner_optimizer.load_state_dict(state_dict["inner_optim"])
+        for i, optimizer in enumerate(self.outer_optimizers):
+            optimizer.load_state_dict(
+                cast(dict[str, torch.Tensor], state_dict["outer_optim"][f"{i}"])
+            )
+
+    def state_dict(self) -> Dict[str, Dict[str, object]]:
+        """
+        Get the state dictionary.
+
+        Returns:
+            The state dictionary.
+        """
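+        # Mirrors load_state_dict: model weights, per-fragment original
+        # parameters, and inner/outer optimizer state keyed by fragment index.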
+ """ + assert self.diloco is not None + + return { + "model": self.model.state_dict(), + "original_params": { + f"{i}": fragment.original_parameters + for i, fragment in enumerate(cast(DiLoCo, self.diloco)._fragments) + }, + "inner_optim": self.inner_optimizer.state_dict(), + "outer_optim": { + f"{i}": optimizer.state_dict() + for i, optimizer in enumerate(self.outer_optimizers) + }, + } + + def train_loop(self) -> dict[str, Any]: + """Run the training loop.""" + # Ensure sync_every is set in diloco_args + all_state_dicts = {} + + if "sync_every" not in self.diloco_args: + self.diloco_args["sync_every"] = 2 + + with DiLoCo( + self.manager, + [layer for layer in self.model.layers], + self.inner_optimizer, + self.outer_optimizers, + backup_device=self.device, + **self.diloco_args, + ) as self.diloco: + while True: + self.runner.event_injector.check(self.rank, self.manager.current_step()) + + manager_curr_step = self.manager.current_step() + if manager_curr_step not in all_state_dicts: + # Store the manager state dict, converting to the right type + all_state_dicts[manager_curr_step] = copy.deepcopy( + self.manager._manager_state_dict() + ) + + batch_size = 1 + inputs = self.model.get_rand_inputs(batch_size, device=self.device) + labels = self.model.get_rand_labels(batch_size, device=self.device) + + out = self.model(inputs) + loss = self.criterion(out, labels) + + self.inner_optimizer.zero_grad() + loss.backward() + self.inner_optimizer.step() + + # after 4 model updates then break + if self.manager.current_step() >= 4: + break + + return all_state_dicts