Support DTensor params in local_sgd/diloco #168


Merged
merged 1 commit into from Apr 21, 2025
57 changes: 49 additions & 8 deletions torchft/local_sgd.py
@@ -14,6 +14,7 @@

import torch
from torch import nn, optim
from torch.distributed.tensor import DTensor
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
from torch.utils.hooks import RemovableHandle
@@ -23,6 +24,20 @@
logger: logging.Logger = logging.getLogger(__name__)


def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
"""
Returns a cloned version of the input tensor. If the input tensor is a DTensor,
it extracts and clones its local representation.
"""
new_tensor = None
if isinstance(t, DTensor):
new_tensor = t.to_local().clone()
else:
new_tensor = t.clone()
new_tensor.grad = None
return new_tensor
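
# Illustrative sketch (not part of the diff): expected behaviour of
# extract_local_tensor on a plain tensor. The returned clone owns its own storage
# and carries no grad, so collectives can mutate it without touching the parameter.
def _extract_local_tensor_demo() -> None:
    p = nn.Parameter(torch.randn(4, 4))
    with torch.no_grad():
        local = extract_local_tensor(p)
        local.add_(1.0)  # modifies only the clone
    assert local.grad is None
    assert not torch.equal(local, p)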


class LocalSGD:
"""
LocalSGD is a context manager that
@@ -126,7 +141,15 @@ def _perform_sync(self) -> None:
if self._manager.should_commit():
# Update the model parameters with the averaged values
for param, avg_param in zip(self._model.parameters(), averaged_parameters):
param.data.copy_(avg_param)
if isinstance(param, DTensor):
# We averaged the local version of the tensor, so we need to copy it back as a DTensor
param.data.copy_(
DTensor.from_local(
avg_param, param.device_mesh, param.placements
)
)
else:
param.data.copy_(avg_param)

def _average(self) -> list[torch.Tensor]:
"""
@@ -136,8 +159,7 @@ def _average(self) -> list[torch.Tensor]:
averaged_parameters = []
for p in self._model.parameters():
# Create a new tensor to store the averaged parameter
p.data.grad = None
avg_param = p.data.clone()
avg_param = extract_local_tensor(p)
works.append(self._manager.allreduce(avg_param))
averaged_parameters.append(avg_param)
for work in works:
@@ -182,6 +204,8 @@ def __init__(
self._outer_optimizer = outer_optimizer
self.original_parameters: Dict[str, torch.Tensor] = {}
for name, p in self._model.named_parameters():
if isinstance(p, DTensor):
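# Size the backup buffer from the local shard only; the full DTensor is
# reconstructed on restore via DTensor.from_local.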
p = extract_local_tensor(p.data)
t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
if (
self._pin_memory
@@ -198,13 +222,23 @@ def _save_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
self.original_parameters[name].copy_(p.data, non_blocking=True)
param_to_local = extract_local_tensor(p.data)
self.original_parameters[name].copy_(param_to_local, non_blocking=True)

def _restore_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
p.data.copy_(self.original_parameters[name], non_blocking=False)
if isinstance(p, DTensor):
Contributor:

If p is a DTensor, does p.copy_(), instead of p.data.copy_(), work?

Member Author:

IIRC it didn't work and caused a segfault.

# We averaged the local version of the tensor, so we need to copy it back as a DTensor
p.data.copy_(
DTensor.from_local(
self.original_parameters[name], p.device_mesh, p.placements
),
non_blocking=False,
)
else:
p.data.copy_(self.original_parameters[name], non_blocking=False)

def __enter__(self) -> "DiLoCo":
# Add optimizer hook which increments the local step counter and syncs if necessary
@@ -252,8 +286,12 @@ def _perform_sync(self) -> None:
"""
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model.named_parameters():
pseudogradient = p.data - self.original_parameters[name]
p.grad = pseudogradient
local_param = extract_local_tensor(p.data)
pseudogradient = local_param - self.original_parameters[name]
if isinstance(p, DTensor):
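# Write the pseudogradient into the DTensor's local shard in place; the grad
# keeps its existing device mesh and placements.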
p.grad._local_tensor = pseudogradient
else:
p.grad = pseudogradient

self._average_grads()
# Restore the parameters back to the previous state
@@ -272,7 +310,10 @@ def _average_grads(self) -> None:
for p in self._model.parameters():
# Perform allreduce on the pseudogradients
assert p.grad is not None
work = self._manager.allreduce(p.grad)
if isinstance(p, DTensor):
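# Allreduce the plain local shard of the pseudogradient; the DTensor wrapper
# itself is not communicated.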
work = self._manager.allreduce(p.grad._local_tensor)
else:
work = self._manager.allreduce(p.grad)
works.append(work)
# Wait for all allreduce operations to complete
for work in works:
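
Taken together, the change applies one pattern wherever a parameter may be a DTensor: run the collective on a clone of the local shard and write the result back through DTensor.from_local with the parameter's original device mesh and placements. A minimal sketch of that round trip, assuming a DTensor parameter param and a manager whose allreduce returns a handle with wait() (names here are illustrative, not taken from the diff):

local = extract_local_tensor(param)  # clone of the local shard, grad cleared
work = manager.allreduce(local)      # average the shard across replicas
work.wait()
param.data.copy_(
    DTensor.from_local(local, param.device_mesh, param.placements)
)
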
22 changes: 20 additions & 2 deletions torchft/local_sgd_test.py
@@ -6,12 +6,13 @@

from typing import Dict
from unittest import TestCase
from unittest.mock import create_autospec
from unittest.mock import MagicMock, create_autospec

import torch
from torch import nn, optim
from torch.distributed.tensor import DTensor

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
from torchft.manager import Manager


@@ -62,6 +63,23 @@ def test_local_sgd_healthy(self) -> None:
self.assertEqual(manager.should_commit.call_count, 1)
self.assertEqual(manager.allreduce.call_count, 4)

def test_extract_local_tensor(self) -> None:
regular_tensor = torch.rand(3, 3, requires_grad=True)
regular_result = extract_local_tensor(regular_tensor)

self.assertTrue(torch.equal(regular_result, regular_tensor))
self.assertIsNone(regular_result.grad)
self.assertNotEqual(id(regular_result), id(regular_tensor))
local_tensor = torch.rand(3, 3, requires_grad=True)
dtensor = MagicMock(spec=DTensor)
dtensor.to_local.return_value = local_tensor
dtensor_result = extract_local_tensor(dtensor)

self.assertTrue(torch.equal(dtensor_result, local_tensor))
self.assertIsNone(dtensor_result.grad)
self.assertNotEqual(id(dtensor_result), id(local_tensor))
dtensor.to_local.assert_called_once()
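
    # Illustrative sketch (not part of the diff): the same check against a real
    # DTensor rather than a MagicMock. Assumes a single-rank gloo process group can
    # be created in the test environment; the TCP port below is arbitrary. The name
    # is deliberately not prefixed with test_ so unittest does not collect it.
    def _extract_local_tensor_with_real_dtensor_sketch(self) -> None:
        import torch.distributed as dist
        from torch.distributed.device_mesh import init_device_mesh
        from torch.distributed.tensor import Replicate, distribute_tensor

        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
        )
        try:
            mesh = init_device_mesh("cpu", (1,))
            dtensor = distribute_tensor(torch.rand(3, 3), mesh, [Replicate()])
            result = extract_local_tensor(dtensor)
            self.assertTrue(torch.equal(result, dtensor.to_local()))
            self.assertIsNone(result.grad)
        finally:
            dist.destroy_process_group()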

def test_local_sgd_recovery(self) -> None:
model = SimpleModel()
optimizer = optim.SGD(model.parameters())