import torch
from torch import nn, optim
+from torch.distributed.tensor import DTensor
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
from torch.utils.hooks import RemovableHandle

logger: logging.Logger = logging.getLogger(__name__)


+def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
+    """
+    Returns a cloned version of the input tensor. If the input tensor is a DTensor,
+    it extracts and clones its local representation.
+    """
+    new_tensor = None
+    if isinstance(t, DTensor):
+        new_tensor = t.to_local().clone()
+    else:
+        new_tensor = t.clone()
+    new_tensor.grad = None
+    return new_tensor
+
+
class LocalSGD:
    """
    LocalSGD is a context manager that
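For orientation, here is a rough, standalone sketch of what the new extract_local_tensor helper does (not part of the patch; it assumes an already-initialized process group, a CPU device mesh, and that the helper above is in scope). For a DTensor it clones the rank-local shard; for a plain tensor it simply clones:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard


def _demo_extract_local_tensor() -> None:
    # Build a 1-D mesh over all ranks and treat a random tensor as this rank's shard.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    local_shard = torch.randn(4)
    dt = DTensor.from_local(local_shard, mesh, [Shard(0)])

    local = extract_local_tensor(dt)  # plain torch.Tensor: a clone of this rank's shard
    local.mul_(2.0)                   # mutating the clone leaves the DTensor untouched
    assert torch.equal(dt.to_local(), local_shard)

    plain = extract_local_tensor(torch.ones(3))  # non-DTensor path: just a clone
    assert not isinstance(plain, DTensor)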
@@ -126,7 +141,15 @@ def _perform_sync(self) -> None:
        if self._manager.should_commit():
            # Update the model parameters with the averaged values
            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
-                param.data.copy_(avg_param)
+                if isinstance(param, DTensor):
+                    # We averaged the local version of the tensor, so we need
+                    # to copy it back as a DTensor.
+                    param.data.copy_(
+                        DTensor.from_local(
+                            avg_param, param.device_mesh, param.placements
+                        )
+                    )
+                else:
+                    param.data.copy_(avg_param)

    def _average(self) -> list[torch.Tensor]:
        """
@@ -136,8 +159,7 @@ def _average(self) -> list[torch.Tensor]:
        averaged_parameters = []
        for p in self._model.parameters():
            # Create a new tensor to store the averaged parameter
-            p.data.grad = None
-            avg_param = p.data.clone()
+            avg_param = extract_local_tensor(p)
            works.append(self._manager.allreduce(avg_param))
            averaged_parameters.append(avg_param)
        for work in works:
@@ -182,6 +204,8 @@ def __init__(
        self._outer_optimizer = outer_optimizer
        self.original_parameters: Dict[str, torch.Tensor] = {}
        for name, p in self._model.named_parameters():
+            if isinstance(p, DTensor):
+                p = extract_local_tensor(p.data)
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
            if (
                self._pin_memory
@@ -198,13 +222,23 @@ def _save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
-                self.original_parameters[name].copy_(p.data, non_blocking=True)
+                param_to_local = extract_local_tensor(p.data)
+                self.original_parameters[name].copy_(param_to_local, non_blocking=True)

    def _restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model.named_parameters():
-                p.data.copy_(self.original_parameters[name], non_blocking=False)
+                if isinstance(p, DTensor):
+                    # We saved the local version of the tensor, so we need to
+                    # copy it back as a DTensor.
+                    p.data.copy_(
+                        DTensor.from_local(
+                            self.original_parameters[name], p.device_mesh, p.placements
+                        ),
+                        non_blocking=False,
+                    )
+                else:
+                    p.data.copy_(self.original_parameters[name], non_blocking=False)

    def __enter__(self) -> "DiLoCo":
        # Add optimizer hook which increments the local step counter and syncs if necessary
@@ -252,8 +286,12 @@ def _perform_sync(self) -> None:
        """
        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
-            pseudogradient = p.data - self.original_parameters[name]
-            p.grad = pseudogradient
+            local_param = extract_local_tensor(p.data)
+            pseudogradient = local_param - self.original_parameters[name]
+            if isinstance(p, DTensor):
+                p.grad._local_tensor = pseudogradient
+            else:
+                p.grad = pseudogradient

        self._average_grads()
        # Restore the parameters back to the previous state
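The pseudogradient above is simply the drift of the (local) weights away from the snapshot taken at the last sync, which the outer optimizer then consumes as if it were a gradient. A toy single-tensor illustration of that mechanism (no DTensor, no distributed, hypothetical values; with lr=1.0 plain SGD exactly undoes the drift):

import torch
from torch import nn, optim


def _pseudogradient_toy() -> None:
    param = nn.Parameter(torch.zeros(4))
    snapshot = param.detach().clone()      # plays the role of original_parameters[name]
    outer_opt = optim.SGD([param], lr=1.0)

    with torch.no_grad():
        param.add_(0.25)                   # stand-in for several inner optimizer steps

    # Pseudogradient: how far the inner loop moved the weights since the snapshot.
    param.grad = param.detach() - snapshot
    outer_opt.step()
    assert torch.allclose(param.detach(), snapshot)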
@@ -272,7 +310,10 @@ def _average_grads(self) -> None:
        for p in self._model.parameters():
            # Perform allreduce on the pseudogradients
            assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
+            if isinstance(p, DTensor):
+                work = self._manager.allreduce(p.grad._local_tensor)
+            else:
+                work = self._manager.allreduce(p.grad)
            works.append(work)
        # Wait for all allreduce operations to complete
        for work in works:
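For reference, the same shard-level averaging can be written with vanilla torch.distributed instead of the torchft manager (a hypothetical sketch; it assumes an initialized process group and uses SUM followed by a divide so it does not depend on backend support for ReduceOp.AVG). The point is that the collective only ever sees a plain tensor, which for a sharded gradient is its private _local_tensor:

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor


def _average_grad_sketch(p: torch.nn.Parameter) -> None:
    # Hypothetical stand-in for _average_grads acting on a single parameter.
    assert p.grad is not None
    # For a sharded gradient, average the raw local shard; otherwise the whole tensor.
    t = p.grad._local_tensor if isinstance(p.grad, DTensor) else p.grad
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t.div_(dist.get_world_size())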