diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 37602f77..acca7991 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -14,6 +14,7 @@
 
 import torch
 from torch import nn, optim
+from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
 from torch.utils.hooks import RemovableHandle
@@ -23,6 +24,20 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
+    """
+    Returns a cloned version of the input tensor. If the input tensor is a DTensor,
+    it extracts and clones its local representation.
+    """
+    new_tensor = None
+    if isinstance(t, DTensor):
+        new_tensor = t.to_local().clone()
+    else:
+        new_tensor = t.clone()
+    new_tensor.grad = None
+    return new_tensor
+
+
 class LocalSGD:
     """
     LocalSGD is a context manager that
@@ -126,7 +141,15 @@ def _perform_sync(self) -> None:
         if self._manager.should_commit():
             # Update the model parameters with the averaged values
             for param, avg_param in zip(self._model.parameters(), averaged_parameters):
-                param.data.copy_(avg_param)
+                if isinstance(param, DTensor):
+                    # we averaged the local version of the tensor so need to copy it back as a DTensor
+                    param.data.copy_(
+                        DTensor.from_local(
+                            avg_param, param.device_mesh, param.placements
+                        )
+                    )
+                else:
+                    param.data.copy_(avg_param)
 
     def _average(self) -> list[torch.Tensor]:
         """
@@ -136,8 +159,7 @@ def _average(self) -> list[torch.Tensor]:
         averaged_parameters = []
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
-            p.data.grad = None
-            avg_param = p.data.clone()
+            avg_param = extract_local_tensor(p)
             works.append(self._manager.allreduce(avg_param))
             averaged_parameters.append(avg_param)
         for work in works:
@@ -182,6 +204,8 @@ def __init__(
         self._outer_optimizer = outer_optimizer
         self.original_parameters: Dict[str, torch.Tensor] = {}
         for name, p in self._model.named_parameters():
+            if isinstance(p, DTensor):
+                p = extract_local_tensor(p.data)
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
             if (
                 self._pin_memory
@@ -198,13 +222,23 @@ def _save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model.named_parameters():
-                self.original_parameters[name].copy_(p.data, non_blocking=True)
+                param_to_local = extract_local_tensor(p.data)
+                self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
     def _restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model.named_parameters():
-                p.data.copy_(self.original_parameters[name], non_blocking=False)
+                if isinstance(p, DTensor):
+                    # we averaged the local version of the tensor so need to copy it back as a DTensor
+                    p.data.copy_(
+                        DTensor.from_local(
+                            self.original_parameters[name], p.device_mesh, p.placements
+                        ),
+                        non_blocking=False,
+                    )
+                else:
+                    p.data.copy_(self.original_parameters[name], non_blocking=False)
 
     def __enter__(self) -> "DiLoCo":
         # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -252,8 +286,12 @@ def _perform_sync(self) -> None:
         """
         # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            pseudogradient = p.data - self.original_parameters[name]
-            p.grad = pseudogradient
+            local_param = extract_local_tensor(p.data)
+            pseudogradient = local_param - self.original_parameters[name]
+            if isinstance(p, DTensor):
+                p.grad._local_tensor = pseudogradient
+            else:
+                p.grad = pseudogradient
 
         self._average_grads()
         # Restore the parameters back to the previous state
@@ -272,7 +310,10 @@ def _average_grads(self) -> None:
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
+            if isinstance(p, DTensor):
+                work = self._manager.allreduce(p.grad._local_tensor)
+            else:
+                work = self._manager.allreduce(p.grad)
             works.append(work)
         # Wait for all allreduce operations to complete
         for work in works:
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index d26b316b..5c3e67b9 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -6,12 +6,13 @@
 
 from typing import Dict
 from unittest import TestCase
-from unittest.mock import create_autospec
+from unittest.mock import MagicMock, create_autospec
 
 import torch
 from torch import nn, optim
+from torch.distributed.tensor import DTensor
 
-from torchft.local_sgd import DiLoCo, LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
 
 
@@ -62,6 +63,23 @@ def test_local_sgd_healthy(self) -> None:
         self.assertEqual(manager.should_commit.call_count, 1)
         self.assertEqual(manager.allreduce.call_count, 4)
 
+    def test_extract_local_tensor(self) -> None:
+        regular_tensor = torch.rand(3, 3, requires_grad=True)
+        regular_result = extract_local_tensor(regular_tensor)
+
+        self.assertTrue(torch.equal(regular_result, regular_tensor))
+        self.assertIsNone(regular_result.grad)
+        self.assertNotEqual(id(regular_result), id(regular_tensor))
+        local_tensor = torch.rand(3, 3, requires_grad=True)
+        dtensor = MagicMock(spec=DTensor)
+        dtensor.to_local.return_value = local_tensor
+        dtensor_result = extract_local_tensor(dtensor)
+
+        self.assertTrue(torch.equal(dtensor_result, local_tensor))
+        self.assertIsNone(dtensor_result.grad)
+        self.assertNotEqual(id(dtensor_result), id(local_tensor))
+        dtensor.to_local.assert_called_once()
+
     def test_local_sgd_recovery(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
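
For context only (not part of the patch): a minimal sketch of how the new extract_local_tensor helper behaves on the plain-tensor path, assuming a checkout where it is importable from torchft.local_sgd. The DTensor branch additionally requires an initialized process group and device mesh, so it is only described in comments here; the variable names below are illustrative, not from the patch.

    import torch

    from torchft.local_sgd import extract_local_tensor

    # Plain-tensor path: the helper returns an independent clone with .grad cleared.
    p = torch.nn.Parameter(torch.randn(4, 4))
    p.grad = torch.ones_like(p)

    local = extract_local_tensor(p.data)
    assert torch.equal(local, p.data)            # same values
    assert local.grad is None                    # gradient state is not carried over
    assert local.data_ptr() != p.data.data_ptr() # separate storage, safe to allreduce in place

    # DTensor path (sketch): extract_local_tensor would return a clone of the local
    # shard via to_local(); the patch later copies averaged values back with
    # DTensor.from_local(avg, p.device_mesh, p.placements).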