
Commit 99f4067

Authored by ptrendx, pre-commit-ci[bot], and timmoon10
Fix return_bias option in LayerNormLinear and LayerNormMLP (#1569)
* Do not apply bias when apply_bias is False
* Bwd fix for LNMLP and tests
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fix for the dbias calculation
* Improve tests and cleaning the logic
* Tightened test tolerances a little
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Revert "Tightened test tolerances a little" (reverts commit 2e20a92)
* Update tests/pytorch/test_numerics.py (Co-authored-by: Tim Moon <[email protected]>)
* Fix the Gelu Aux type
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Remove use_fc1_bias option

Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
1 parent bee4649 commit 99f4067
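
For context on the contract being fixed: when a Transformer Engine module is constructed with return_bias=True, it is expected to skip the bias addition inside its fused GEMM and instead hand the bias tensor back to the caller, so the add can be fused into a later kernel (dropout, residual add, etc.). A minimal sketch of that usage, based on the module arguments exercised by the tests below (illustrative sizes; assumes transformer_engine.pytorch is installed and a CUDA device is available):

import torch
import transformer_engine.pytorch as te

# return_bias=True: the module still owns a bias parameter but must not apply it itself.
layer = te.LayerNormLinear(
    in_features=1024,
    out_features=4096,
    bias=True,
    return_bias=True,
    params_dtype=torch.bfloat16,
    device="cuda",
)

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
out, bias = layer(x)   # bias comes back separately instead of being added
out = out + bias       # caller applies it later, e.g. fused with dropout/residual

Per the commit message, LayerNormLinear and LayerNormMLP could previously still apply the bias when apply_bias was False, and the dbias calculation in backward needed fixing; the diffs below address both.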

File tree

5 files changed: +100 −67 lines

tests/pytorch/test_numerics.py

Lines changed: 80 additions & 40 deletions
@@ -5,7 +5,7 @@
 from collections import OrderedDict
 import math
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Optional
 import pytest
 import copy
 import random
@@ -331,9 +331,9 @@ def __init__(
         in_features: int,
         out_features: int,
         eps: float,
-        bias: bool = True,
         normalization: str = "LayerNorm",
         zero_centered_gamma: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -347,7 +347,7 @@ def __init__(
         else:
             raise RuntimeError("Unsupported normalization")
 
-        self.linear = nn.Linear(in_features, out_features)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(self.layernorm(x))
@@ -447,6 +447,7 @@ def __init__(
         eps: float = 1e-5,
         activation="gelu",
         normalization: str = "LayerNorm",
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -462,8 +463,8 @@ def __init__(
         fc1_output_features = ffn_hidden_size
         self.gelu = _supported_act[activation]
 
-        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
-        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
+        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
+        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)
 
     def forward(self, x):
         t = self.gelu(self.fc1(self.ln(x)))
@@ -1039,6 +1040,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
     inp_hidden_states.retain_grad()
 
     out = block(inp_hidden_states)
+    if isinstance(out, (List, Tuple)):
+        out = out[0]
     loss = out.sum()
     loss.backward()
 
@@ -1117,32 +1120,53 @@ def test_dpa_accuracy(dtype, bs, model):
     assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
 
 
+class TestReturnBiasModule(nn.Module):
+    def __init__(self, mod, **kwargs):
+        super().__init__()
+        self.te_module = mod(**kwargs)
+        self.return_bias = kwargs["return_bias"]
+        self.bias = kwargs["bias"]
+
+    def forward(self, x):
+        if self.return_bias:
+            out, bias = self.te_module(x)
+            if self.bias:
+                out = out + bias
+            return out
+        return self.te_module(x)
+
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["small"])
-def test_linear_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_linear_accuracy(dtype, bs, model, return_bias, bias):
     config = model_configs[model]
 
-    te_linear = Linear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        bias=True,
+    te_linear = TestReturnBiasModule(
+        Linear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
         params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )
 
     torch_linear = torch.nn.Linear(
         config.hidden_size,
         4 * config.hidden_size,
-        bias=True,
+        bias=bias,
         device="cuda",
         dtype=dtype,
-    ).eval()
+    )
 
     # Share params
     with torch.no_grad():
-        torch_linear.weight = Parameter(te_linear.weight.clone())
-        torch_linear.bias = Parameter(te_linear.bias.clone())
+        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
+        if bias:
+            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())
 
     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
@@ -1265,41 +1289,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_linear_accuracy(
+    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
+):
     config = model_configs[model]
 
-    te_ln_linear = LayerNormLinear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        config.eps,
-        bias=True,
+    te_ln_linear = TestReturnBiasModule(
+        LayerNormLinear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
+        eps=config.eps,
         normalization=normalization,
         params_dtype=dtype,
        zero_centered_gamma=zero_centered_gamma,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )
 
     torch_ln_linear = (
         TorchLayerNormLinear(
             config.hidden_size,
             4 * config.hidden_size,
             config.eps,
-            bias=True,
             normalization=normalization,
             zero_centered_gamma=zero_centered_gamma,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )
 
     # Share params
     with torch.no_grad():
-        torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
+        torch_ln_linear.layernorm.weight = Parameter(
+            te_ln_linear.te_module.layer_norm_weight.clone()
+        )
         if normalization != "RMSNorm":
-            torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
-        torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
-        torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
+            torch_ln_linear.layernorm.bias = Parameter(
+                te_ln_linear.te_module.layer_norm_bias.clone()
+            )
+        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
+        if bias:
+            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())
 
     te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
@@ -1339,39 +1373,45 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
     config = model_configs[model]
 
-    te_ln_mlp = LayerNormMLP(
-        config.hidden_size,
-        4 * config.hidden_size,
+    te_ln_mlp = TestReturnBiasModule(
+        LayerNormMLP,
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
         activation=activation,
         normalization=normalization,
         params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )
 
     torch_ln_mlp = (
         TorchLayerNormMLP(
             config.hidden_size,
             4 * config.hidden_size,
            activation=activation,
            normalization=normalization,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )
 
     # Share params
     with torch.no_grad():
-        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
+        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
        if normalization != "RMSNorm":
-            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
-        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
-        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
-        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
-        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
+            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
+        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
+        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
+        if bias:
+            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
+            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())
 
     te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)
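
As a usage note, the TestReturnBiasModule wrapper above normalizes all four (bias, return_bias) combinations back to a single output tensor, so the shared accuracy helpers stay unchanged. A standalone sketch of how it behaves, with illustrative sizes (run inside this test file's namespace, or with the wrapper class copied alongside; assumes a CUDA device):

import torch
from transformer_engine.pytorch import Linear

wrapped = TestReturnBiasModule(
    Linear,
    in_features=256,
    out_features=1024,
    params_dtype=torch.float32,
    return_bias=True,
    bias=True,
    device="cuda",
)

x = torch.randn(8, 256, device="cuda")
out, bias = wrapped.te_module(x)                     # raw TE output: bias handed back, not applied
torch.testing.assert_close(wrapped(x), out + bias)   # the wrapper re-applies it

Re-applying the returned bias inside the wrapper is what lets the same torch reference module (built with bias=bias) serve as ground truth for every parametrization.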

transformer_engine/common/gemm/cublaslt_gemm.cu

Lines changed: 3 additions & 0 deletions
@@ -351,6 +351,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                                       &pre_gelu_out, sizeof(pre_gelu_out)));
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
         operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
+    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
+        operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof(aux_type)));
   }
 
   NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,

transformer_engine/pytorch/csrc/extensions/gemm.cpp

Lines changed: 8 additions & 1 deletion
@@ -36,6 +36,10 @@ namespace transformer_engine::pytorch {
 
 namespace detail {
 
+bool is_low_precision(const DType type) {
+  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
+}
+
 std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
                                        const NVTEShape& B_shape, const bool transb) {
   // Flatten outer dims to get 2D matrices
@@ -96,6 +100,9 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
   TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
 
+  const bool low_precision =
+      detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
+
   // Check tensor dimensions
   const auto& A_shape = A_tensor.shape();
   const auto& B_shape = B_tensor.shape();
@@ -137,7 +144,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
 
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = bias_type;
+  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);
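
The behavioral change in this file is the dtype chosen for the pre-GELU auxiliary output ("Fix the Gelu Aux type"): only FP8 GEMMs keep following bias_type, while non-FP8 GEMMs now use the output tensor's dtype. A Python-level paraphrase of that selection rule (names here are illustrative, not the extension's actual API):

# Paraphrase of the gelu_type selection above; dtype names are illustrative strings.
FP8_DTYPES = {"kFloat8E4M3", "kFloat8E5M2"}

def gelu_aux_dtype(a_dtype: str, b_dtype: str, bias_dtype: str, out_dtype: str) -> str:
    low_precision = a_dtype in FP8_DTYPES or b_dtype in FP8_DTYPES
    # FP8 GEMMs keep the old behavior (aux buffer follows the bias dtype);
    # otherwise the pre-GELU buffer now follows the GEMM output dtype.
    return bias_dtype if low_precision else out_dtype

assert gelu_aux_dtype("kFloat8E4M3", "kFloat8E4M3", "kBFloat16", "kBFloat16") == "kBFloat16"
assert gelu_aux_dtype("kBFloat16", "kBFloat16", "kFloat32", "kBFloat16") == "kBFloat16"

The cublaslt_gemm.cu hunk above complements this by telling cuBLASLt the aux buffer's data type explicitly, so the epilogue output no longer has to share the bias dtype.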

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 2 additions & 9 deletions
@@ -79,7 +79,6 @@ def forward(
         ln_bias: Union[torch.Tensor, None],
         weight: torch.Tensor,
         bias: torch.Tensor,
-        use_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -422,7 +421,7 @@ def forward(
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
-        ctx.use_bias = use_bias
+        ctx.use_bias = bias is not None
         ctx.sequence_parallel = sequence_parallel
         ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp_shape
@@ -756,10 +755,6 @@ def backward(
             # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
             clear_tensor_data(ln_out_total)
 
-        # Don't return grad bias if not needed
-        if not ctx.use_bias:
-            grad_bias = None
-
         # Synchronize tensor parallel communication
         if ln_out_total_work is not None:
             ln_out_total_work.wait()
@@ -841,7 +836,6 @@ def backward(
             dbeta,
             wgrad,
             grad_bias,
-            None,  # use_bias
             None,  # eps
             None,  # is_first_microbatch
             None,  # fp8
@@ -1344,8 +1338,7 @@ def forward(
             self.layer_norm_weight,
             self.layer_norm_bias,
             weight_tensor,
-            bias_tensor,
-            self.apply_bias and not self.gemm_bias_unfused_add,
+            bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
            self.eps,
             is_first_microbatch,
             self.fp8,
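
The net effect in this module is that the explicit use_bias argument to the autograd function is gone: the caller passes the bias tensor only when it should actually be applied, and forward/backward infer everything from whether that tensor is None. A hedged sketch of the pattern with placeholder names (not the real _LayerNormLinear implementation):

from typing import Optional
import torch

class _FusedLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, weight: torch.Tensor,
                bias: Optional[torch.Tensor]):
        ctx.use_bias = bias is not None          # derived, not passed as a separate flag
        ctx.save_for_backward(inp, weight)
        out = inp @ weight.t()
        return out + bias if ctx.use_bias else out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        inp, weight = ctx.saved_tensors
        # grad_bias is produced only when a bias tensor was actually passed in forward
        grad_bias = grad_out.sum(dim=0) if ctx.use_bias else None
        return grad_out @ weight, grad_out.t() @ inp, grad_bias

Callers then pass bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, exactly as the last hunk above does, so grad_bias is naturally None whenever the bias was not applied.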

0 commit comments
