# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

+import importlib
import numbers
+
import torch
-from torch.nn.parameter import Parameter
from torch.nn import init
-import importlib
+from torch.nn.parameter import Parameter

from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+
    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
+
    HAVE_FUSED_LAYER_NORM = True
except:
    HAVE_FUSED_LAYER_NORM = False


class FusedLayerNorm(torch.nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-5,
-                 persist_layer_norm=True,
-                 sequence_parallel=False,
-                 zero_centered_gamma=False):
+    def __init__(
+        self,
+        hidden_size,
+        eps=1e-5,
+        persist_layer_norm=True,
+        sequence_parallel=False,
+        zero_centered_gamma=False,
+    ):
        super().__init__()

        self.zero_centered_gamma = zero_centered_gamma

        # List of hidden sizes supported in the persistent layer norm kernel.
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
-        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
-            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
-            24576, 25600, 30720, 32768, 40960, 49152, 65536]
+        persist_ln_hidden_sizes = [
+            1024,
+            1536,
+            2048,
+            2304,
+            3072,
+            3840,
+            4096,
+            5120,
+            6144,
+            8192,
+            10240,
+            12288,
+            12800,
+            15360,
+            16384,
+            18432,
+            20480,
+            24576,
+            25600,
+            30720,
+            32768,
+            40960,
+            49152,
+            65536,
+        ]
        if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
            persist_layer_norm = False

@@ -58,32 +87,33 @@ def __init__(self, hidden_size, eps=1e-5,
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

+    def reset_parameters(self):

-    def reset_parameters(self):
-
-        if self.zero_centered_gamma:
-            init.zeros_(self.weight)
-            init.zeros_(self.bias)
-        else:
-            init.ones_(self.weight)
-            init.zeros_(self.bias)
+        if self.zero_centered_gamma:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)

-    def forward(self, input):
+    def forward(self, input):

-        weight = self.weight + 1 if self.zero_centered_gamma else self.weight
+        weight = self.weight + 1 if self.zero_centered_gamma else self.weight

-        if self.persist_layer_norm:
-            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
+        if self.persist_layer_norm:
+            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

-            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
-            # a populated '_base' field). This will result in schedule.py's
-            # deallocate_output_tensor() throwing an error, so a viewless tensor is
-            # created to prevent this.
-            output = make_viewless_tensor(inp=output,
-                                          requires_grad=input.requires_grad,
-                                          keep_graph=True)
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(
+                inp=output, requires_grad=input.requires_grad, keep_graph=True
+            )

-        else:
-            output = FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.hidden_size, self.eps)
+        else:
+            output = FusedLayerNormAffineFunction.apply(
+                input, weight, self.bias, self.hidden_size, self.eps
+            )

-        return output
+        return output
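
For readers trying the reformatted module, a minimal usage sketch follows. It is hypothetical: it assumes apex is installed, a CUDA device is available, and that the class is importable from `megatron.core.fused_layer_norm` (only the class itself appears in this diff).

```python
import torch

# Hypothetical import path; the diff above shows only the class definition.
from megatron.core.fused_layer_norm import FusedLayerNorm

# 1024 is in persist_ln_hidden_sizes, so the persistent kernel path is taken
# whenever apex's FastLayerNormFN imported successfully.
ln = FusedLayerNorm(hidden_size=1024).cuda()

x = torch.randn(8, 32, 1024, device="cuda", requires_grad=True)
y = ln(x)  # normalized over the last dimension; same shape as x
```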
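
The `zero_centered_gamma` logic above can also be checked without apex. A small sketch using only stock PyTorch: the module stores `gamma - 1` (which `reset_parameters` initializes to zeros) and adds 1 back in `forward`, so a freshly initialized layer behaves exactly like a standard LayerNorm.

```python
import torch
import torch.nn.functional as F

hidden_size = 1024
stored_weight = torch.zeros(hidden_size)  # gamma - 1, as init.zeros_ leaves it
bias = torch.zeros(hidden_size)

x = torch.randn(4, hidden_size)
# forward() applies stored_weight + 1 as the effective gamma ...
out = F.layer_norm(x, (hidden_size,), stored_weight + 1, bias, eps=1e-5)
# ... which matches a standard LayerNorm with gamma initialized to ones.
ref = F.layer_norm(x, (hidden_size,), torch.ones(hidden_size), bias, eps=1e-5)
assert torch.allclose(out, ref)
```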