Commit b59d1d8

BestJuly and ksivaman authored
[PyTorch] Fix issues for MCore DDP in grouped GEMM. (#1609)
fix mcore DDP error

Signed-off-by: lit <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 945a559 · commit b59d1d8

1 file changed: +22 -12 lines


Diff for: transformer_engine/pytorch/module/grouped_linear.py

+22 -12

@@ -78,8 +78,8 @@ def forward(
         skip_fp8_weight_update,
         *weights_and_biases,
     ) -> torch.Tensor:
-
         # pylint: disable=missing-function-docstring
+
         num_gemms = len(m_splits)
         weights = weights_and_biases[:num_gemms]
         biases = weights_and_biases[num_gemms:]
@@ -180,7 +180,12 @@ def forward(

         ctx.weights_shape_1 = weights[0].shape[1]

-        tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
+        tensors_to_save, tensor_objects = prepare_for_saving(
+            *inputmats,
+            *weights_fp8,
+            *weights,
+            *biases,
+        )
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects

@@ -220,7 +225,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             N = ctx.num_gemms
             inputmats = saved_tensors[:N]
             weights = saved_tensors[N : 2 * N]
-            biases = saved_tensors[2 * N : 3 * N]
+            origin_weights = saved_tensors[2 * N : 3 * N]
+            biases = saved_tensors[3 * N : 4 * N]
             main_grads = ctx.main_grads

             if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TOSO
@@ -311,21 +317,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 # Deallocate input tensor
                 clear_tensor_data(*inputmats)

-                def handle_custom_ddp_from_mcore(w, wgrad):
+                def handle_custom_ddp_from_mcore(weight, wgrad):
                     if ctx.weights_requires_grad:
-                        if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
-                            w.grad_added_to_main_grad = True
-                            if getattr(w, "zero_out_wgrad", False):
+                        # Handle custom DDP from mcore.
+                        if ctx.fuse_wgrad_accumulation and hasattr(
+                            weight, "grad_added_to_main_grad"
+                        ):
+                            weight.grad_added_to_main_grad = True
+                            if getattr(weight, "zero_out_wgrad", False):
                                 wgrad = torch.zeros(
-                                    w.main_grad.shape,
-                                    dtype=w.dtype,
+                                    weight.main_grad.shape,
+                                    dtype=weight.dtype,
                                     device=torch.cuda.current_device(),
                                     requires_grad=False,
                                 )
                             else:
                                 wgrad = torch.empty(
-                                    w.main_grad.shape,
-                                    dtype=w.dtype,
+                                    weight.main_grad.shape,
+                                    dtype=weight.dtype,
                                     device=torch.cuda.current_device(),
                                     requires_grad=False,
                                 )
@@ -336,7 +345,8 @@ def handle_custom_ddp_from_mcore(w, wgrad):
                     return wgrad

                 wgrad_list = [
-                    handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
+                    handle_custom_ddp_from_mcore(weight, wgrad)
+                    for weight, wgrad in zip(origin_weights, wgrad_list)
                 ]
             else:
                 wgrad_list = [None] * ctx.num_gemms
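Note (not part of the commit): below is a minimal sketch of the Megatron-Core custom-DDP contract that this fix relies on. MCore DDP attaches main_grad, grad_added_to_main_grad, and zero_out_wgrad to the original nn.Parameter, and a cast or quantized copy of the weight does not carry those Python attributes; that is why forward() now saves the original weights alongside the FP8 tensors and backward() passes origin_weights into handle_custom_ddp_from_mcore. The helper name handle_custom_ddp, the shapes, and the processed-copy example are illustrative assumptions, not code from the file.

import torch

# An MCore-style parameter: DDP pre-allocates `main_grad` and expects the module to
# accumulate weight gradients directly into it, then set `grad_added_to_main_grad`.
weight = torch.nn.Parameter(torch.randn(16, 32))
weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)
weight.grad_added_to_main_grad = False
weight.zero_out_wgrad = False

# A processed copy (e.g. an FP8/BF16 workspace tensor) loses those attributes, which is
# why backward must see the original parameter rather than only the cast view.
processed = weight.detach().to(torch.bfloat16)
assert not hasattr(processed, "main_grad")

def handle_custom_ddp(param, wgrad, fuse_wgrad_accumulation=True):
    # Mirrors the role of handle_custom_ddp_from_mcore in the diff: flag the parameter
    # and return a dummy wgrad so autograd still receives a tensor of the right shape.
    if fuse_wgrad_accumulation and hasattr(param, "grad_added_to_main_grad"):
        param.grad_added_to_main_grad = True
        if getattr(param, "zero_out_wgrad", False):
            return torch.zeros(param.main_grad.shape, dtype=param.dtype, requires_grad=False)
        return torch.empty(param.main_grad.shape, dtype=param.dtype, requires_grad=False)
    return wgrad

wgrad = torch.randn(16, 32)   # weight gradient produced by the grouped GEMM
weight.main_grad += wgrad     # fused accumulation into MCore's preallocated buffer
dummy = handle_custom_ddp(weight, wgrad)
print(weight.grad_added_to_main_grad)  # True: DDP knows the grad already landed in main_grad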
