@@ -5,7 +5,7 @@
 from collections import OrderedDict
 import math
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Tuple, Optional
 import pytest
 import copy
 import random
@@ -331,9 +331,9 @@ def __init__(
         in_features: int,
         out_features: int,
         eps: float,
-        bias: bool = True,
         normalization: str = "LayerNorm",
         zero_centered_gamma: bool = False,
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -347,7 +347,7 @@ def __init__(
         else:
             raise RuntimeError("Unsupported normalization")

-        self.linear = nn.Linear(in_features, out_features)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(self.layernorm(x))
@@ -447,6 +447,7 @@ def __init__(
         eps: float = 1e-5,
         activation="gelu",
         normalization: str = "LayerNorm",
+        bias: bool = True,
     ):
         super().__init__()
         if normalization == "LayerNorm":
@@ -462,8 +463,8 @@ def __init__(
             fc1_output_features = ffn_hidden_size
         self.gelu = _supported_act[activation]

-        self.fc1 = nn.Linear(hidden_size, fc1_output_features)
-        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)
+        self.fc1 = nn.Linear(hidden_size, fc1_output_features, bias=bias)
+        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=bias)

     def forward(self, x):
         t = self.gelu(self.fc1(self.ln(x)))
@@ -1039,6 +1040,8 @@ def _test_granular_accuracy(block, bs, dtype, config):
     inp_hidden_states.retain_grad()

     out = block(inp_hidden_states)
+    if isinstance(out, (List, Tuple)):
+        out = out[0]
     loss = out.sum()
     loss.backward()

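For context on the new `out[0]` unwrapping in `_test_granular_accuracy`: Transformer Engine modules constructed with `return_bias=True` skip the bias addition and hand the bias back as a second return value, so the helper keeps only the primary output. A minimal sketch of that contract, assuming `transformer_engine.pytorch` is installed and a CUDA device is available (sizes are illustrative):

import torch
import transformer_engine.pytorch as te

# With return_bias=True the module does not add the bias to its output;
# it returns the bias tensor so the caller can fuse the add elsewhere.
linear = te.Linear(128, 512, bias=True, return_bias=True, device="cuda")
x = torch.randn(32, 128, device="cuda")

out, bias = linear(x)  # (un-biased output, bias tensor)
out = out + bias       # caller applies the bias explicitly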
@@ -1117,32 +1120,53 @@ def test_dpa_accuracy(dtype, bs, model):
     assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)


+class TestReturnBiasModule(nn.Module):
+    def __init__(self, mod, **kwargs):
+        super().__init__()
+        self.te_module = mod(**kwargs)
+        self.return_bias = kwargs["return_bias"]
+        self.bias = kwargs["bias"]
+
+    def forward(self, x):
+        if self.return_bias:
+            out, bias = self.te_module(x)
+            if self.bias:
+                out = out + bias
+            return out
+        return self.te_module(x)
+
+
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("bs", batch_sizes)
 @pytest.mark.parametrize("model", ["small"])
-def test_linear_accuracy(dtype, bs, model):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_linear_accuracy(dtype, bs, model, return_bias, bias):
     config = model_configs[model]

-    te_linear = Linear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        bias=True,
+    te_linear = TestReturnBiasModule(
+        Linear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
         params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_linear = torch.nn.Linear(
         config.hidden_size,
         4 * config.hidden_size,
-        bias=True,
+        bias=bias,
         device="cuda",
         dtype=dtype,
-    ).eval()
+    )

     # Share params
     with torch.no_grad():
-        torch_linear.weight = Parameter(te_linear.weight.clone())
-        torch_linear.bias = Parameter(te_linear.bias.clone())
+        torch_linear.weight = Parameter(te_linear.te_module.weight.clone())
+        if bias:
+            torch_linear.bias = Parameter(te_linear.te_module.bias.clone())

     te_outputs = _test_granular_accuracy(te_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config)
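Stripped of the parametrization and the wrapper, the test above reduces to an elementwise comparison between `te.Linear` and `torch.nn.Linear` with shared parameters. A standalone sketch of that check (hypothetical sizes and illustrative tolerances; assumes `transformer_engine` and a CUDA device are available):

import torch
import transformer_engine.pytorch as te

te_linear = te.Linear(64, 256, bias=True, params_dtype=torch.float32, device="cuda")
ref_linear = torch.nn.Linear(64, 256, bias=True, device="cuda", dtype=torch.float32)

# Share parameters so any mismatch comes from the kernels, not from initialization.
with torch.no_grad():
    ref_linear.weight.copy_(te_linear.weight)
    ref_linear.bias.copy_(te_linear.bias)

x = torch.randn(16, 64, device="cuda", dtype=torch.float32)
torch.testing.assert_close(te_linear(x), ref_linear(x), atol=5e-2, rtol=1e-2)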
@@ -1265,41 +1289,51 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_linear_accuracy(
+    dtype, bs, model, normalization, zero_centered_gamma, return_bias, bias
+):
     config = model_configs[model]

-    te_ln_linear = LayerNormLinear(
-        config.hidden_size,
-        4 * config.hidden_size,
-        config.eps,
-        bias=True,
+    te_ln_linear = TestReturnBiasModule(
+        LayerNormLinear,
+        in_features=config.hidden_size,
+        out_features=4 * config.hidden_size,
+        eps=config.eps,
         normalization=normalization,
         params_dtype=dtype,
         zero_centered_gamma=zero_centered_gamma,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_ln_linear = (
         TorchLayerNormLinear(
             config.hidden_size,
             4 * config.hidden_size,
             config.eps,
-            bias=True,
             normalization=normalization,
             zero_centered_gamma=zero_centered_gamma,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )

     # Share params
     with torch.no_grad():
-        torch_ln_linear.layernorm.weight = Parameter(te_ln_linear.layer_norm_weight.clone())
+        torch_ln_linear.layernorm.weight = Parameter(
+            te_ln_linear.te_module.layer_norm_weight.clone()
+        )
         if normalization != "RMSNorm":
-            torch_ln_linear.layernorm.bias = Parameter(te_ln_linear.layer_norm_bias.clone())
-        torch_ln_linear.linear.weight = Parameter(te_ln_linear.weight.clone())
-        torch_ln_linear.linear.bias = Parameter(te_ln_linear.bias.clone())
+            torch_ln_linear.layernorm.bias = Parameter(
+                te_ln_linear.te_module.layer_norm_bias.clone()
+            )
+        torch_ln_linear.linear.weight = Parameter(te_ln_linear.te_module.weight.clone())
+        if bias:
+            torch_ln_linear.linear.bias = Parameter(te_ln_linear.te_module.bias.clone())

     te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
@@ -1339,39 +1373,45 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
 @pytest.mark.parametrize("model", ["small"])
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
-def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
+@pytest.mark.parametrize("return_bias", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, return_bias, bias):
     config = model_configs[model]

-    te_ln_mlp = LayerNormMLP(
-        config.hidden_size,
-        4 * config.hidden_size,
+    te_ln_mlp = TestReturnBiasModule(
+        LayerNormMLP,
+        hidden_size=config.hidden_size,
+        ffn_hidden_size=4 * config.hidden_size,
         activation=activation,
         normalization=normalization,
         params_dtype=dtype,
+        return_bias=return_bias,
+        bias=bias,
         device="cuda",
-    ).eval()
+    )

     torch_ln_mlp = (
         TorchLayerNormMLP(
             config.hidden_size,
             4 * config.hidden_size,
             activation=activation,
             normalization=normalization,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
-        .eval()
     )

     # Share params
     with torch.no_grad():
-        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.layer_norm_weight.clone())
+        torch_ln_mlp.ln.weight = Parameter(te_ln_mlp.te_module.layer_norm_weight.clone())
         if normalization != "RMSNorm":
-            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.layer_norm_bias.clone())
-        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.fc1_weight.clone())
-        torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.fc1_bias.clone())
-        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.fc2_weight.clone())
-        torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.fc2_bias.clone())
+            torch_ln_mlp.ln.bias = Parameter(te_ln_mlp.te_module.layer_norm_bias.clone())
+        torch_ln_mlp.fc1.weight = Parameter(te_ln_mlp.te_module.fc1_weight.clone())
+        torch_ln_mlp.fc2.weight = Parameter(te_ln_mlp.te_module.fc2_weight.clone())
+        if bias:
+            torch_ln_mlp.fc1.bias = Parameter(te_ln_mlp.te_module.fc1_bias.clone())
+            torch_ln_mlp.fc2.bias = Parameter(te_ln_mlp.te_module.fc2_bias.clone())

     te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config)
     torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config)