
Commit f4ecd97

consistency
1 parent bbda5e4 commit f4ecd97

File tree

1 file changed: +10 -10 lines changed


nn/transformer.py

Lines changed: 10 additions & 10 deletions
@@ -19,17 +19,17 @@ class TransformerEncoderLayer(nn.Module):
   def __init__(self, out_dim: nn.Dim, *,
                self_attention: Union[nn.SelfAttention, Any],
                ff_dim: nn.Dim,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                norm_eps: float = 1e-6,
                norm_first: bool = True,
                norm=nn.layer_norm) -> None:
     """
     :param out_dim: output dimension, PyTorch name: d_model
     :param self_attention: module which does self attention
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward
+    :param ff_activation: activation function
     :param dropout: Dropout value, PyTorch name: dropout
-    :param activation: activation function
     :param norm_eps: Epsilon value for layer normalization
     :param norm_first: if ``True`` will perform normalization before other att and ff operations, otherwise after
     :param norm: normalization function
@@ -39,7 +39,7 @@ def __init__(self, out_dim: nn.Dim, *,

     self.linear_ff = nn.Linear(ff_dim)
     self.linear_out = nn.Linear(out_dim)
-    self.activation = activation
+    self.activation = ff_activation
     self.norm_first = norm_first
     self.norm_eps = norm_eps
     self.norm = norm
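
The two hunks above rename the TransformerEncoderLayer keyword from activation to ff_activation, so it matches ff_dim. A minimal call-site sketch after this commit, assuming the returnn_common nn package exports these classes and that nn.FeatureDim provides the dims (both assumptions, not part of this diff):

    from returnn_common import nn  # assumed import path

    out_dim = nn.FeatureDim("model", 512)   # d_model
    ff_dim = nn.FeatureDim("ff", 2048)      # dim_feedforward
    # self-attention module passed in explicitly, as the layer expects
    self_att = nn.SelfAttention(
      key_dim_total=out_dim, value_dim_total=out_dim, num_heads=8, att_dropout=0.1)
    layer = nn.TransformerEncoderLayer(
      out_dim=out_dim, self_attention=self_att, ff_dim=ff_dim,
      ff_activation=nn.relu)  # was activation=nn.relu before this commit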
@@ -117,8 +117,8 @@ def __init__(self, out_dim: nn.Dim, *,
                enc_dec_attention: nn.AttentionFunc,
                causal_self_attention: Union[nn.CausalSelfAttention, Any],
                ff_dim: nn.Dim,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                norm_eps: float = 1e-6,
                norm_first: bool = True,
                norm=nn.layer_norm):
@@ -127,8 +127,8 @@ def __init__(self, out_dim: nn.Dim, *,
     :param enc_dec_attention: module or func which does encoder decoder attention
     :param causal_self_attention: module or func which does causal self attention
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward
+    :param ff_activation: activation function
     :param dropout: Dropout value
-    :param activation: activation function
     :param norm_eps: Epsilon value for layer normalization
     :param norm_first: if ``True`` will perform normalization before other att and ff operations, otherwise after
     :param norm: normalization function
@@ -143,7 +143,7 @@ def __init__(self, out_dim: nn.Dim, *,
     self.norm = norm
     self.norm_first = norm_first
     self.norm_eps = norm_eps
-    self.activation = activation
+    self.activation = ff_activation
     self.dropout = dropout

   @nn.scoped
@@ -262,9 +262,9 @@ def __init__(self,
                num_encoder_layers: int = 6,
                num_decoder_layers: int = 6,
                ff_dim: nn.Dim = nn.NotSpecified,
+               ff_activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                dropout: float = 0.1,
                att_dropout: float = 0.1,
-               activation: Callable[[nn.Tensor], nn.Tensor] = nn.relu,
                custom_encoder: Optional[Union[TransformerEncoder, Any]] = None,
                custom_decoder: Optional[Union[TransformerDecoder, Any]] = None,
                custom_encoder_layer: Optional[Union[TransformerEncoderLayer, Any]] = None,
@@ -288,9 +288,9 @@ def __init__(self,
     :param num_encoder_layers: Number of encoder layers
     :param num_decoder_layers: Number of decoder layers
     :param ff_dim: dimension of feedforward layer, PyTorch name: dim_feedforward. 4 * out_dim by default.
+    :param ff_activation: activation function
     :param dropout: Dropout value, PyTorch name: dropout
     :param att_dropout: dropout value for attention
-    :param activation: activation function
     :param custom_encoder: Custom Encoder to replace the standard encoder
     :param custom_decoder: Custom Decoder to replace the standard decoder
     :param custom_encoder_layer: Custom Encoder layer to replace the standard layer if custom_encoder and
@@ -319,7 +319,7 @@ def __init__(self,
       enc_self_attention = nn.SelfAttention(
         key_dim_total=out_dim, value_dim_total=out_dim, num_heads=num_heads, att_dropout=att_dropout)
       encoder_layer = TransformerEncoderLayer(
-        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=activation, norm_eps=norm_eps, norm=norm,
+        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation, norm_eps=norm_eps, norm=norm,
         norm_first=norm_first, self_attention=enc_self_attention)
       self.encoder = TransformerEncoder(
         encoder_layer=encoder_layer, num_layers=num_encoder_layers, norm=norm, norm_eps=norm_eps)
@@ -343,7 +343,7 @@ def __init__(self,
       if ff_dim is nn.NotSpecified:
         ff_dim = out_dim * 4
       decoder_layer = TransformerDecoderLayer(
-        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, activation=activation, norm_eps=norm_eps, norm=norm,
+        out_dim=out_dim, ff_dim=ff_dim, dropout=dropout, ff_activation=ff_activation, norm_eps=norm_eps, norm=norm,
         norm_first=norm_first, causal_self_attention=dec_causal_self_attention, enc_dec_attention=enc_dec_attention)
       self.decoder = TransformerDecoder(
         decoder_layer=decoder_layer, num_layers=num_decoder_layers, norm=norm, norm_eps=norm_eps)
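
At the Transformer level the rename is the same: callers that passed activation= now pass ff_activation=. A hedged sketch of an updated top-level construction, again assuming the returnn_common nn package and nn.FeatureDim, and that the remaining constructor arguments (not all shown in this diff) keep their previous defaults:

    from returnn_common import nn  # assumed import path

    out_dim = nn.FeatureDim("model", 512)
    ff_dim = nn.FeatureDim("ff", 2048)  # defaults to 4 * out_dim if left as nn.NotSpecified
    model = nn.Transformer(
      out_dim=out_dim, ff_dim=ff_dim, num_heads=8,
      num_encoder_layers=6, num_decoder_layers=6,
      dropout=0.1, att_dropout=0.1,
      ff_activation=nn.relu)  # renamed from activation in this commit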
