Merged
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -402,7 +402,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=wgrad_gemm_use_split_accumulator,
- accumulate=accumulate_wgrad_into_param_main_grad,
+ accumulate=(
+     accumulate_wgrad_into_param_main_grad
+     if not getattr(weights[0], "overwrite_main_grad", False)
+     else False
+ ),
)
# WGRAD
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
@@ -519,7 +523,9 @@ class GroupedLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
- size to accumulate gradients in.
+ size to accumulate gradients in. If the weight tensor also has the
+ attribute `overwrite_main_grad` set to `True`, the weight gradient
+ overwrites `main_grad` instead of accumulating into it.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
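The hunks above only change how the `accumulate` flag for the weight-gradient GEMM is chosen. A minimal sketch of that selection logic, using a hypothetical helper name that does not exist in Transformer Engine:

```python
# Minimal sketch of the accumulate-flag selection added above; the helper
# name is hypothetical and not part of Transformer Engine.
def _wgrad_accumulate(weight, accumulate_wgrad_into_param_main_grad: bool) -> bool:
    # Accumulate into `main_grad` only when fused wgrad accumulation is
    # requested and the weight does not ask for its buffer to be overwritten.
    if getattr(weight, "overwrite_main_grad", False):
        return False
    return accumulate_wgrad_into_param_main_grad


class _FakeWeight:
    pass


w = _FakeWeight()
print(_wgrad_accumulate(w, True))   # True  -> wgrad is added into main_grad
w.overwrite_main_grad = True
print(_wgrad_accumulate(w, True))   # False -> wgrad overwrites main_grad
```

The same check is repeated per weight in `GroupedLinear`, `LayerNormLinear`, `LayerNormMLP`, and `Linear` below.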
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -849,7 +849,11 @@ def backward(
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1125,7 +1129,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
- size to accumulate gradients in.
+ size to accumulate gradients in. If the weight tensor also has the
+ attribute `overwrite_main_grad` set to `True`, the weight gradient
+ overwrites `main_grad` instead of accumulating into it.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
16 changes: 13 additions & 3 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -948,7 +948,11 @@ def backward(
else ctx.activation_dtype
),
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc1_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
@@ -1189,7 +1193,11 @@ def fc2_wgrad_gemm(
else ctx.activation_dtype
),
"quantization_params": ctx.fc1_grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(fc2_weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
@@ -1484,7 +1492,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
- size to accumulate gradients in.
+ size to accumulate gradients in. If the weight tensor also has the
+ attribute `overwrite_main_grad` set to `True`, the weight gradient
+ overwrites `main_grad` instead of accumulating into it.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
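For context, `accumulate` here has the usual GEMM-epilogue meaning: `True` adds the computed weight gradient into the pre-allocated `main_grad` buffer, `False` replaces the buffer's contents. A plain-PyTorch illustration of those semantics (not the fused cuBLAS path these modules actually use):

```python
import torch

# Illustration of accumulate semantics only; the real kernels fuse this
# into the wgrad GEMM epilogue rather than running separate ops.
main_grad = torch.full((4, 4), 7.0)   # pre-allocated buffer (possibly stale)
wgrad = torch.ones(4, 4)              # weight gradient from the current step

accumulate = False                    # what overwrite_main_grad=True selects
if accumulate:
    main_grad += wgrad                # main_grad = old contents + wgrad
else:
    main_grad.copy_(wgrad)            # main_grad = wgrad, old contents discarded

print(main_grad[0, 0].item())         # 1.0, the stale value 7.0 is gone
```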
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -843,7 +843,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
else False
),
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
@@ -1061,7 +1065,9 @@ class Linear(TransformerEngineBaseModule):
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
- size to accumulate gradients in.
+ size to accumulate gradients in. If the weight tensor also has the
+ attribute `overwrite_main_grad` set to `True`, the weight gradient
+ overwrites `main_grad` instead of accumulating into it.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
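A hedged usage sketch of the documented contract for `Linear`: the module is constructed with `fuse_wgrad_accumulation=True`, the caller pre-allocates `weight.main_grad`, and setting `weight.overwrite_main_grad = True` makes backward write into that buffer instead of adding to whatever it held. Shapes, dtypes, and the dummy data are illustrative, not taken from this PR:

```python
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024, bias=False, fuse_wgrad_accumulation=True)

# Contract from the docstring: the caller owns a pre-allocated main_grad buffer.
layer.weight.main_grad = torch.full_like(layer.weight, 3.0)  # deliberately stale

# New in this change: ask for the buffer to be overwritten, not accumulated into.
layer.weight.overwrite_main_grad = True

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
layer(x).sum().backward()

# layer.weight.main_grad now holds only this step's wgrad; the stale 3.0
# values were replaced rather than added to.
```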
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -80,7 +80,9 @@ class BasicLinear(BasicOperation):
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful. This is primarily intended to integrate with
- Megatron-LM.
+ Megatron-LM. If the weight tensor also has the attribute
+ `overwrite_main_grad` set to `True`, the weight gradient overwrites
+ `main_grad` instead of accumulating into it.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
@@ -1019,6 +1021,7 @@ def op_backward(
weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
+ accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
@@ -59,6 +59,7 @@ def fuser_backward(
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
+ accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
@@ -60,6 +60,7 @@ def fuser_backward(
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
+ accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
@@ -523,6 +523,7 @@ def fuser_backward(
weight_param = linear_op.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
+ accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
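The fuser hunks above only compute `accumulate_into_main_grad`; that flag is then passed to the wgrad GEMM in the same way as in the module paths. One plausible way a Megatron-style training loop could drive the new attribute is to overwrite on the first micro-batch of an accumulation window and accumulate afterwards, removing the need to zero `main_grad` buffers between iterations. A sketch under that assumption; the loop and names are hypothetical:

```python
# Hypothetical gradient-accumulation loop; only the overwrite_main_grad
# toggling is the point, everything else is a placeholder.
def train_step(model, micro_batches, optimizer):
    params = [p for p in model.parameters() if hasattr(p, "main_grad")]
    for i, batch in enumerate(micro_batches):
        # First micro-batch overwrites the (possibly stale) main_grad buffers,
        # later micro-batches accumulate into them.
        for p in params:
            p.overwrite_main_grad = (i == 0)
        model(batch).sum().backward()
    optimizer.step()  # a Megatron-style optimizer consumes param.main_grad
```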