@@ -78,8 +78,8 @@ def forward(
         skip_fp8_weight_update,
         *weights_and_biases,
     ) -> torch.Tensor:
-
         # pylint: disable=missing-function-docstring
+
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
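For context on the surrounding code (not part of the diff): a minimal, self-contained sketch of how the flat `*weights_and_biases` varargs are split into per-GEMM weights and biases. The shapes and values below are assumptions for illustration only.

```python
import torch

# Toy setup (hypothetical shapes): three GEMMs, each contributing one weight
# and one bias, passed as a single flat varargs tuple, weights first.
m_splits = [4, 8, 16]
weights_and_biases = tuple(
    [torch.randn(32, 16) for _ in m_splits]  # N weights ...
    + [torch.randn(32) for _ in m_splits]    # ... followed by N biases
)

num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]   # first N entries
biases = weights_and_biases[num_gemms:]    # remaining N entries

assert len(weights) == len(biases) == num_gemms
```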
@@ -180,7 +180,12 @@ def forward(
 
         ctx.weights_shape_1 = weights[0].shape[1]
 
-        tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
+        tensors_to_save, tensor_objects = prepare_for_saving(
+            *inputmats,
+            *weights_fp8,
+            *weights,
+            *biases,
+        )
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
 
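This hunk extends the flat list handed to `prepare_for_saving` with the original `weights` in addition to the FP8 casts. Below is a minimal sketch of the underlying save/restore contract, using plain `ctx.save_for_backward` and a toy `autograd.Function` rather than the real `prepare_for_saving` helper; the names and shapes are assumptions for illustration.

```python
import torch

class _FlatSaveToy(torch.autograd.Function):
    """Toy stand-in for the flat save/restore pattern (not the real helper)."""

    @staticmethod
    def forward(ctx, inp, weight_fp8, weight, bias):
        # Everything is saved as one flat, ordered tuple; backward must
        # unpack it in exactly the same order.
        ctx.save_for_backward(inp, weight_fp8, weight, bias)
        return inp @ weight.t() + bias

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight_fp8, weight, bias = ctx.saved_tensors  # same order as saved
        grad_input = grad_output @ weight
        grad_weight = grad_output.t() @ inp
        grad_bias = grad_output.sum(dim=0)
        return grad_input, None, grad_weight, grad_bias

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
b = torch.zeros(16, requires_grad=True)
out = _FlatSaveToy.apply(x, w.detach(), w, b)  # detached copy stands in for the FP8 cast
out.sum().backward()
```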
@@ -220,7 +225,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         N = ctx.num_gemms
         inputmats = saved_tensors[:N]
         weights = saved_tensors[N : 2 * N]
-        biases = saved_tensors[2 * N : 3 * N]
+        origin_weights = saved_tensors[2 * N : 3 * N]
+        biases = saved_tensors[3 * N : 4 * N]
         main_grads = ctx.main_grads
 
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
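The new slicing mirrors the extended save list from the forward hunk: inserting the `*weights` block shifts the biases from `[2N:3N]` to `[3N:4N]`. A quick standalone check of that layout, with placeholder strings standing in for real tensors:

```python
# Saved-tensor layout after the change:
# [ inputmats | weights_fp8 | origin_weights | biases ], each block of length N.
N = 3
inputmats = [f"in{i}" for i in range(N)]
weights_fp8 = [f"w8_{i}" for i in range(N)]
origin_weights = [f"w{i}" for i in range(N)]
biases = [f"b{i}" for i in range(N)]

saved_tensors = [*inputmats, *weights_fp8, *origin_weights, *biases]

assert saved_tensors[:N] == inputmats
assert saved_tensors[N : 2 * N] == weights_fp8
assert saved_tensors[2 * N : 3 * N] == origin_weights  # newly saved block
assert saved_tensors[3 * N : 4 * N] == biases          # biases shifted by N
```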
@@ -311,21 +317,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             # Deallocate input tensor
             clear_tensor_data(*inputmats)
 
-            def handle_custom_ddp_from_mcore(w, wgrad):
+            def handle_custom_ddp_from_mcore(weight, wgrad):
                 if ctx.weights_requires_grad:
-                    if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
-                        w.grad_added_to_main_grad = True
-                        if getattr(w, "zero_out_wgrad", False):
+                    # Handle custom DDP from mcore.
+                    if ctx.fuse_wgrad_accumulation and hasattr(
+                        weight, "grad_added_to_main_grad"
+                    ):
+                        weight.grad_added_to_main_grad = True
+                        if getattr(weight, "zero_out_wgrad", False):
                             wgrad = torch.zeros(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
                         else:
                             wgrad = torch.empty(
-                                w.main_grad.shape,
-                                dtype=w.dtype,
+                                weight.main_grad.shape,
+                                dtype=weight.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
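A minimal sketch, assuming a Megatron-Core-style parameter that carries a preallocated `main_grad` buffer, of what the renamed helper does: when wgrad accumulation is fused into `main_grad`, autograd still expects some tensor to be returned, so a placeholder shaped like `main_grad` is produced (zeros or uninitialized, depending on `zero_out_wgrad`). The toy parameter and the CPU-only placeholder below are assumptions for illustration, not the real mcore DDP objects.

```python
import torch

# Toy parameter mimicking the attributes a Megatron-Core-style DDP attaches.
weight = torch.nn.Parameter(torch.randn(16, 8))
weight.main_grad = torch.zeros_like(weight)       # preallocated accumulation buffer
weight.grad_added_to_main_grad = False
weight.zero_out_wgrad = False

def placeholder_wgrad(weight, fuse_wgrad_accumulation=True):
    # The actual wgrad is presumed to have been accumulated into
    # weight.main_grad already; return a placeholder of the right shape.
    if fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
        weight.grad_added_to_main_grad = True
        if getattr(weight, "zero_out_wgrad", False):
            return torch.zeros(weight.main_grad.shape, dtype=weight.dtype,
                               requires_grad=False)
        return torch.empty(weight.main_grad.shape, dtype=weight.dtype,
                           requires_grad=False)
    return None  # no fusion: the caller keeps the real wgrad

wgrad = placeholder_wgrad(weight)
assert weight.grad_added_to_main_grad
assert wgrad.shape == weight.main_grad.shape
```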
@@ -336,7 +345,8 @@ def handle_custom_ddp_from_mcore(w, wgrad):
                 return wgrad
 
             wgrad_list = [
-                handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
+                handle_custom_ddp_from_mcore(weight, wgrad)
+                for weight, wgrad in zip(origin_weights, wgrad_list)
             ]
         else:
             wgrad_list = [None] * ctx.num_gemms
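The comprehension now zips the saved original parameters (`origin_weights`) rather than the FP8 copies, presumably because attributes such as `main_grad` and `grad_added_to_main_grad` live on the original `nn.Parameter` objects and not on any cast copy saved for the GEMM. A tiny illustration of that distinction, with hypothetical names and a half-precision cast standing in for the FP8 tensor:

```python
import torch

# The mcore-style attributes are attached to the original Parameter ...
origin_weight = torch.nn.Parameter(torch.randn(16, 8))
origin_weight.main_grad = torch.zeros_like(origin_weight)

# ... while a cast/quantized copy is a fresh tensor without those attributes.
weight_cast = origin_weight.detach().to(torch.float16)

assert hasattr(origin_weight, "main_grad")
assert not hasattr(weight_cast, "main_grad")
```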