@@ -19,17 +19,17 @@ class TransformerEncoderLayer(nn.Module):
   def __init__(self, out_dim: nn.Dim, *,
                self_attention: Union[nn.SelfAttention, Any],
                ff_dim: nn.Dim,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                norm_eps: float = 1e-6,
                norm_first: bool = True,
                norm=nn.layer_norm) -> None:
     """
     :param out_dim: output dimension, PyTorch name: d_model
     :param self_attention: module which does self attention
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward
+    :param ff_activation: activation function
     :param dropout: Dropout value, PyTorch name: dropout
-    :param activation: activation function
     :param norm_eps: Epsilon value for layer normalization
     :param norm_first: if ``True`` will perform normalization before other att and ff operations, otherwise after
     :param norm: normalization function
@@ -39,7 +39,7 @@ def __init__(self, out_dim: nn.Dim, *,
 
     self.linear_ff = nn.Linear(ff_dim)
     self.linear_out = nn.Linear(out_dim)
-    self.activation = activation
+    self.activation = ff_activation
     self.norm_first = norm_first
     self.norm_eps = norm_eps
     self.norm = norm
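For reference, a call-site sketch of the renamed keyword on the encoder layer (not part of this diff; `nn.FeatureDim` and `nn.gelu` are assumed to exist in this `nn` namespace, and the remaining arguments follow the signature shown above):

# Hedged usage sketch, not from this commit.
model_dim = nn.FeatureDim("model", 512)   # assumed helper for creating a feature Dim
ff_dim = nn.FeatureDim("ff", 2048)
self_att = nn.SelfAttention(
  key_dim_total=model_dim, value_dim_total=model_dim, num_heads=8, att_dropout=0.1)
enc_layer = TransformerEncoderLayer(
  out_dim=model_dim, self_attention=self_att, ff_dim=ff_dim,
  ff_activation=nn.gelu,  # previously passed as `activation=...`; nn.gelu assumed available
  dropout=0.1)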
@@ -117,8 +117,8 @@ def __init__(self, out_dim: nn.Dim, *,
                enc_dec_attention: nn.AttentionFunc,
                causal_self_attention: Union[nn.CausalSelfAttention, Any],
                ff_dim: nn.Dim,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                norm_eps: float = 1e-6,
                norm_first: bool = True,
                norm=nn.layer_norm):
@@ -127,8 +127,8 @@ def __init__(self, out_dim: nn.Dim, *,
     :param enc_dec_attention: module or func which does encoder decoder attention
     :param causal_self_attention: module or func which does causal self attention
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward
+    :param ff_activation: activation function
     :param dropout: Dropout value
-    :param activation: activation function
     :param norm_eps: Epsilon value for layer normalization
     :param norm_first: if ``True`` will perform normalization before other att and ff operations, otherwise after
     :param norm: normalization function
@@ -143,7 +143,7 @@ def __init__(self, out_dim: nn.Dim, *,
     self.norm = norm
     self.norm_first = norm_first
     self.norm_eps = norm_eps
-    self.activation = activation
+    self.activation = ff_activation
     self.dropout = dropout
 
   @nn.scoped
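The decoder layer changes mirror the encoder layer, so existing call sites only need the keyword renamed. A rough sketch, reusing `model_dim` and `ff_dim` from the sketch above; the constructor arguments of `nn.CausalSelfAttention` and the name `nn.dot_attention` are assumptions mirroring the encoder-side construction shown further below, not taken from this diff:

# Hedged usage sketch, not from this commit.
causal_att = nn.CausalSelfAttention(
  key_dim_total=model_dim, value_dim_total=model_dim, num_heads=8, att_dropout=0.1)
dec_layer = TransformerDecoderLayer(
  out_dim=model_dim, ff_dim=ff_dim,
  causal_self_attention=causal_att,
  enc_dec_attention=nn.dot_attention,  # any nn.AttentionFunc; name assumed
  ff_activation=nn.relu)  # was `activation=nn.relu`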
@@ -262,9 +262,9 @@ def __init__(self,
                num_encoder_layers: int = 6,
                num_decoder_layers: int = 6,
                ff_dim: nn.Dim = nn.NotSpecified,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
                att_dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                custom_encoder: Optional[Union[TransformerEncoder, Any]] = None,
                custom_decoder: Optional[Union[TransformerDecoder, Any]] = None,
                custom_encoder_layer: Optional[Union[TransformerEncoderLayer, Any]] = None,
@@ -288,9 +288,9 @@ def __init__(self,
     :param num_encoder_layers: Number of encoder layers
     :param num_decoder_layers: Number of decoder layers
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward. 4 * out_dim by default.
+    :param ff_activation: activation function
     :param dropout: Dropout value, PyTorch name: dropout
     :param att_dropout: dropout value for attention
-    :param activation: activation function
     :param custom_encoder: Custom Encoder to replace the standard encoder
     :param custom_decoder: Custom Decoder to replace the standard decoder
     :param custom_encoder_layer: Custom Encoder layer to replace the standard layer if custom_encoder and
@@ -319,7 +319,7 @@ def __init__(self,
       enc_self_attention = nn.SelfAttention(
         key_dim_total=out_dim, value_dim_total=out_dim, num_heads=num_heads, att_dropout=att_dropout)
       encoder_layer = TransformerEncoderLayer(
-        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=activation, norm_eps=norm_eps, norm=norm,
+        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation, norm_eps=norm_eps, norm=norm,
         norm_first=norm_first, self_attention=enc_self_attention)
     self.encoder = TransformerEncoder(
       encoder_layer=encoder_layer, num_layers=num_encoder_layers, norm=norm, norm_eps=norm_eps)
@@ -343,7 +343,7 @@ def __init__(self,
       if ff_dim is nn.NotSpecified:
         ff_dim = out_dim * 4
       decoder_layer = TransformerDecoderLayer(
-        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=activation, norm_eps=norm_eps, norm=norm,
+        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation, norm_eps=norm_eps, norm=norm,
         norm_first=norm_first, causal_self_attention=dec_causal_self_attention, enc_dec_attention=enc_dec_attention)
     self.decoder = TransformerDecoder(
       decoder_layer=decoder_layer, num_layers=num_decoder_layers, norm=norm, norm_eps=norm_eps)
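At the top level, the transformer class now also takes `ff_activation` and forwards it to both the default encoder and decoder layers, as the two hunks above show. A minimal construction sketch, assuming the class is exposed as `Transformer` and that `nn.FeatureDim` and `nn.gelu` exist in this namespace; the keyword names themselves come from the signature in this diff:

# Hedged usage sketch, not from this commit.
model_dim = nn.FeatureDim("model", 512)
transformer = Transformer(
  out_dim=model_dim, num_heads=8,
  num_encoder_layers=6, num_decoder_layers=6,
  ff_activation=nn.gelu,  # was `activation=nn.gelu`; ff_dim defaults to 4 * out_dim
  dropout=0.1, att_dropout=0.1)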