@@ -84,7 +84,7 @@ class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer: Union[TransformerEncoderLayer, Any], *, num_layers: int,
             norm=nn.LayerNorm, norm_eps: float = 1e-6):
    """
-   :param encoder_layer: Encoder layer to be stacked num_layers times
+   :param encoder_layer: Encoder layer to be stacked num_layers times (copies of it, no param sharing)
    :param num_layers: Number of layers
    :param norm: normalization function, e.g. nn.LayerNorm()
    :param norm_eps: Epsilon value for layer normalization
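The copying behaviour called out in the new docstring can be illustrated with a minimal sketch. The class name, the explicit `d_model` argument, and the norm wiring below are assumptions for illustration, not this repo's actual API:

```python
import copy

import torch.nn as nn


class StackedEncoderSketch(nn.Module):
    """Sketch only: the given layer is deep-copied num_layers times,
    so the stacked copies do NOT share parameters."""

    def __init__(self, encoder_layer: nn.Module, *, num_layers: int,
                 d_model: int, norm=nn.LayerNorm, norm_eps: float = 1e-6):
        super().__init__()
        # Independent deep copies -> no automatic parameter sharing.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.norm = norm(d_model, eps=norm_eps)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
```

For example, passing a single `nn.TransformerEncoderLayer(d_model=512, nhead=8)` yields `num_layers` independently trained copies rather than one shared layer.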
@@ -290,9 +290,11 @@ def __init__(self,
:param custom_encoder: Custom Encoder to replace the standard encoder
:param custom_decoder: Custom Decoder to replace the standard decoder
:param custom_encoder_layer: Custom Encoder layer to replace the standard layer if custom_encoder and
-   custom_encoder_layer are given custom_encoder will be preferred
+   custom_encoder_layer are given, custom_encoder will be preferred.
+   Copies of it will be made for each layer, so there is no automatic param sharing.
:param custom_decoder_layer: Custom Decoder layer to replace the standard layer if custom_decoder and
    custom_decoder_layer are given custom_decoder will be preferred
+   Copies of it will be made for each layer, so there is no automatic param sharing.
:param norm_eps: Epsilon value for layer normalization
:param norm: function for layer normalization
:param norm_first: if ``True`` will perform normalization before other att and ff operations, otherwise after
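The precedence between `custom_encoder` and `custom_encoder_layer`, and the per-layer copying, can be sketched with the stock `torch.nn` classes. `build_encoder` and its arguments are illustrative only, not this repo's API; the decoder side works the same way:

```python
import copy

import torch.nn as nn


def build_encoder(num_layers: int, d_model: int, nhead: int,
                  custom_encoder: nn.Module = None,
                  custom_encoder_layer: nn.Module = None) -> nn.Module:
    """Sketch: a full custom_encoder wins outright; otherwise a (custom or
    standard) layer is stacked, and stacking deep-copies the layer, so the
    resulting layers do not share parameters."""
    if custom_encoder is not None:
        return custom_encoder  # preferred over custom_encoder_layer
    layer = (custom_encoder_layer if custom_encoder_layer is not None
             else nn.TransformerEncoderLayer(d_model, nhead))
    # nn.TransformerEncoder deep-copies `layer` num_layers times.
    return nn.TransformerEncoder(layer, num_layers)
```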