
Commit 360c5c5

Support DTensor params in local_sgd/diloco (#168)
1 parent 4961e56 commit 360c5c5

File tree

2 files changed, +69 -10 lines

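The diff below leans on the DTensor local-shard API throughout: `to_local()` exposes the rank-local shard as a plain `torch.Tensor`, and `from_local()` rebuilds a DTensor on the same device mesh and placements. A minimal single-process sketch of that round-trip (illustrative only, not part of the commit; the gloo setup, mesh, and shapes are made up):

# Illustrative round-trip, not from the commit: single rank, gloo backend, CPU mesh.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
param = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])  # a DTensor parameter

local = param.to_local().clone()   # plain torch.Tensor shard, safe to mutate/average
local.mul_(0.5)                    # stand-in for the averaging done by the manager
param.copy_(DTensor.from_local(local, param.device_mesh, param.placements))

dist.destroy_process_group()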

torchft/local_sgd.py

+49 -8
@@ -14,6 +14,7 @@
 
 import torch
 from torch import nn, optim
+from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
 from torch.utils.hooks import RemovableHandle
@@ -23,6 +24,20 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
+    """
+    Returns a cloned version of the input tensor. If the input tensor is a DTensor,
+    it extracts and clones its local representation.
+    """
+    new_tensor = None
+    if isinstance(t, DTensor):
+        new_tensor = t.to_local().clone()
+    else:
+        new_tensor = t.clone()
+    new_tensor.grad = None
+    return new_tensor
+
+
 class LocalSGD:
     """
     LocalSGD is a context manager that
@@ -126,7 +141,15 @@ def _perform_sync(self) -> None:
         if self._manager.should_commit():
             # Update the model parameters with the averaged values
             for param, avg_param in zip(self._model.parameters(), averaged_parameters):
-                param.data.copy_(avg_param)
+                if isinstance(param, DTensor):
+                    # we averaged the local version of the tensor so need to copy it back as a DTensor
+                    param.data.copy_(
+                        DTensor.from_local(
+                            avg_param, param.device_mesh, param.placements
+                        )
+                    )
+                else:
+                    param.data.copy_(avg_param)
 
     def _average(self) -> list[torch.Tensor]:
         """
@@ -136,8 +159,7 @@ def _average(self) -> list[torch.Tensor]:
         averaged_parameters = []
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
-            p.data.grad = None
-            avg_param = p.data.clone()
+            avg_param = extract_local_tensor(p)
             works.append(self._manager.allreduce(avg_param))
             averaged_parameters.append(avg_param)
         for work in works:
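Taken together, the two hunks above change the averaging path to: pull out the local shard, allreduce that shard through the manager, and only wrap it back into a DTensor when writing the averaged values into the parameter. A schematic restatement of that control flow (illustrative only; `manager.allreduce` and `manager.should_commit` are the torchft calls visible in the diff, everything else just mirrors the hunks):

# Schematic only -- mirrors LocalSGD._average/_perform_sync as changed above.
import torch
from torch import nn
from torch.distributed.tensor import DTensor

from torchft.local_sgd import extract_local_tensor
from torchft.manager import Manager


def average_and_commit(manager: Manager, model: nn.Module) -> None:
    works, averaged = [], []
    for p in model.parameters():
        avg = extract_local_tensor(p)          # local shard for DTensor, clone otherwise
        works.append(manager.allreduce(avg))   # async average across replicas
        averaged.append(avg)
    for work in works:
        work.wait()
    if manager.should_commit():
        for p, avg in zip(model.parameters(), averaged):
            if isinstance(p, DTensor):
                p.data.copy_(DTensor.from_local(avg, p.device_mesh, p.placements))
            else:
                p.data.copy_(avg)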
@@ -182,6 +204,8 @@ def __init__(
         self._outer_optimizer = outer_optimizer
         self.original_parameters: Dict[str, torch.Tensor] = {}
         for name, p in self._model.named_parameters():
+            if isinstance(p, DTensor):
+                p = extract_local_tensor(p.data)
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
             if (
                 self._pin_memory
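The reason for the `extract_local_tensor(p.data)` rebinding in `__init__`: on a sharded DTensor, `p.shape` is the global logical shape, while the backup buffer only needs to hold this rank's shard, so `torch.empty` should see the local tensor's shape and dtype. A hypothetical helper (not from the commit) capturing the same sizing logic:

# Hypothetical helper, not in the commit: size a backup buffer from the *local*
# shard of a (possibly distributed) parameter, as DiLoCo.__init__ now does.
import torch
from torch.distributed.tensor import DTensor

from torchft.local_sgd import extract_local_tensor


def make_backup_buffer(p: torch.Tensor, backup_device: torch.device) -> torch.Tensor:
    if isinstance(p, DTensor):
        p = extract_local_tensor(p.data)  # local shard; may be smaller than the global p.shape
    return torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)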
@@ -198,13 +222,23 @@ def _save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model.named_parameters():
-                self.original_parameters[name].copy_(p.data, non_blocking=True)
+                param_to_local = extract_local_tensor(p.data)
+                self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
     def _restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model.named_parameters():
-                p.data.copy_(self.original_parameters[name], non_blocking=False)
+                if isinstance(p, DTensor):
+                    # we averaged the local version of the tensor so need to copy it back as a DTensor
+                    p.data.copy_(
+                        DTensor.from_local(
+                            self.original_parameters[name], p.device_mesh, p.placements
+                        ),
+                        non_blocking=False,
+                    )
+                else:
+                    p.data.copy_(self.original_parameters[name], non_blocking=False)
 
     def __enter__(self) -> "DiLoCo":
         # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -252,8 +286,12 @@ def _perform_sync(self) -> None:
         """
         # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            pseudogradient = p.data - self.original_parameters[name]
-            p.grad = pseudogradient
+            local_param = extract_local_tensor(p.data)
+            pseudogradient = local_param - self.original_parameters[name]
+            if isinstance(p, DTensor):
+                p.grad._local_tensor = pseudogradient
+            else:
+                p.grad = pseudogradient
 
         self._average_grads()
         # Restore the parameters back to the previous state
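The pseudogradient step above is the heart of DiLoCo's outer update: the difference between the post-inner-loop weights and the weights saved in `original_parameters` is written into `.grad` (for DTensors, directly into the grad's local shard via `_local_tensor`), so that `_average_grads` can allreduce exactly that shard before the outer optimizer steps from the restored weights. A tiny worked illustration with made-up numbers (not from the commit):

# Made-up numbers, illustrative only: the local-shard arithmetic behind the hunk above.
import torch

saved = torch.tensor([1.0, 2.0])         # original_parameters[name] (a plain local tensor)
after_inner = torch.tensor([0.7, 2.4])   # extract_local_tensor(p.data) after the inner steps
pseudogradient = after_inner - saved     # tensor([-0.3000,  0.4000]); assigned to p.grad
                                         # (or to p.grad._local_tensor for a DTensor)
# _average_grads() then allreduces this shard across replicas, the parameters are
# restored to `saved`, and the outer optimizer consumes the averaged pseudogradient.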
@@ -272,7 +310,10 @@ def _average_grads(self) -> None:
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
+            if isinstance(p, DTensor):
+                work = self._manager.allreduce(p.grad._local_tensor)
+            else:
+                work = self._manager.allreduce(p.grad)
             works.append(work)
         # Wait for all allreduce operations to complete
         for work in works:

torchft/local_sgd_test.py

+20 -2
@@ -6,12 +6,13 @@
 
 from typing import Dict
 from unittest import TestCase
-from unittest.mock import create_autospec
+from unittest.mock import MagicMock, create_autospec
 
 import torch
 from torch import nn, optim
+from torch.distributed.tensor import DTensor
 
-from torchft.local_sgd import DiLoCo, LocalSGD
+from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
 
 
@@ -62,6 +63,23 @@ def test_local_sgd_healthy(self) -> None:
         self.assertEqual(manager.should_commit.call_count, 1)
         self.assertEqual(manager.allreduce.call_count, 4)
 
+    def test_extract_local_tensor(self) -> None:
+        regular_tensor = torch.rand(3, 3, requires_grad=True)
+        regular_result = extract_local_tensor(regular_tensor)
+
+        self.assertTrue(torch.equal(regular_result, regular_tensor))
+        self.assertIsNone(regular_result.grad)
+        self.assertNotEqual(id(regular_result), id(regular_tensor))
+        local_tensor = torch.rand(3, 3, requires_grad=True)
+        dtensor = MagicMock(spec=DTensor)
+        dtensor.to_local.return_value = local_tensor
+        dtensor_result = extract_local_tensor(dtensor)
+
+        self.assertTrue(torch.equal(dtensor_result, local_tensor))
+        self.assertIsNone(dtensor_result.grad)
+        self.assertNotEqual(id(dtensor_result), id(local_tensor))
+        dtensor.to_local.assert_called_once()
+
     def test_local_sgd_recovery(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
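For context, a hypothetical smoke-test style sketch of driving DiLoCo end to end with a mocked Manager, in the spirit of the tests above. Everything here is illustrative: the DiLoCo constructor arguments are assumptions inferred from the `__init__` hunk and the existing tests, not a documented signature, and a real run would use a model whose parameters are DTensors (e.g. sharded with FSDP2's `fully_shard`) together with a real torchft `Manager`.

# Hypothetical sketch, not from the commit. Constructor arguments are assumed;
# check torchft/local_sgd.py for the actual signature.
from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager

model = nn.Linear(8, 8)   # with FSDP2's fully_shard, these params would be DTensors
inner_opt = optim.AdamW(model.parameters(), lr=3e-4)
outer_opt = optim.SGD(model.parameters(), lr=0.7)
manager = create_autospec(Manager)  # a mocked Manager, as in the unit tests above

with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=2):
    for _ in range(4):
        inner_opt.zero_grad()
        model(torch.randn(4, 8)).sum().backward()
        inner_opt.step()  # every `sync_every` inner steps triggers an outer sync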
