
Commit 1321b9b

Authored by guyueh1, pre-commit-ci[bot], and timmoon10
Ensure weight transpose is valid for Hopper FP8 training (#1596)
* Update usage of weightmat before saving for backward

Signed-off-by: Guyue Huang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix for layernorm mlp

Signed-off-by: Guyue Huang <[email protected]>

---------

Signed-off-by: Guyue Huang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
1 parent e14d147 commit 1321b9b

File tree: 3 files changed (+19, -1 lines)


transformer_engine/pytorch/module/layernorm_linear.py (+6, -1)

@@ -341,7 +341,7 @@ def forward(
             weight.requires_grad and parallel_mode == "column" and sequence_parallel
         )

-        # Input with column-wise usage is needed for dgrad GEMM.
+        # Input with column-wise usage is needed for wgrad GEMM.
         if backward_needs_input:
             if isinstance(ln_out, QuantizedTensor):
                 # For sequence parallel in vanilla FP8, rowwise data is
@@ -350,6 +350,11 @@ def forward(
                 if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                     ln_out.update_usage(rowwise_usage=False)

+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if inp.requires_grad:
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)
+
         if cpu_offloading:
             if fp8 and weightmat is not None:
                 set_offloading_param(weightmat, "weight_offloading", True)
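All three modules in this commit apply the same guard before the quantized weight is saved for backward: if the input needs a gradient, the weight's column-wise (transposed) data is requested so it is still available when the dgrad GEMM runs. A minimal standalone sketch of that pattern follows; the helper name and the import path are assumptions for illustration, not part of the commit.

# Hypothetical helper, not Transformer Engine source: mirrors the guard each
# module adds before the quantized weight is saved for backward.
# The import path below is an assumption.
from transformer_engine.pytorch.tensor import QuantizedTensor


def ensure_weight_transpose_valid(weightmat, input_requires_grad: bool) -> None:
    """Request column-wise (transposed) data on a quantized weight so the
    dgrad GEMM in the backward pass has a valid transpose to consume."""
    if input_requires_grad and isinstance(weightmat, QuantizedTensor):
        weightmat.update_usage(columnwise_usage=True)

Per the commit title and the added comments, the dgrad GEMM consumes the weight's column-wise data (its transpose), which is what makes this necessary for Hopper FP8 training.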

transformer_engine/pytorch/module/layernorm_mlp.py (+8)

@@ -442,6 +442,14 @@ def forward(
                 ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
                 extra_output=rs_out,
             )
+
+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if is_grad_enabled and inp.requires_grad:
+            if isinstance(fc1_weight_final, QuantizedTensor):
+                fc1_weight_final.update_usage(columnwise_usage=True)
+            if isinstance(fc2_weight_final, QuantizedTensor):
+                fc2_weight_final.update_usage(columnwise_usage=True)
+
         if not is_grad_enabled:
             clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
         else:

transformer_engine/pytorch/module/linear.py (+5)

@@ -272,6 +272,11 @@ def forward(
                     inputmat.update_usage(rowwise_usage=False)
             saved_inputmat = inputmat

+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if inp.requires_grad:
+            if isinstance(weightmat, QuantizedTensor):
+                weightmat.update_usage(columnwise_usage=True)
+
         if cpu_offloading:
             set_offloading_param(weight, "weight_offloading", True)
             set_offloading_param(weightmat, "weight_offloading", True)

0 commit comments
