@@ -754,15 +754,26 @@ class LayerNormLayer(_ConcatInputLayer):
   """
   layer_class = "layer_norm"
 
-  def __init__(self, epsilon=1e-6, **kwargs):
+  def __init__(self, in_dim=None, out_dim=None, epsilon=1e-6, **kwargs):
     """
+    :param Dim|None in_dim:
+    :param Dim|None out_dim:
     :param float epsilon:
     """
     super(LayerNormLayer, self).__init__(**kwargs)
     assert not self.input_data.sparse
     x = self.input_data.placeholder
-    dim = self.input_data.dim
-    axis = self.input_data.feature_dim_axis
+    if not in_dim and out_dim:
+      in_dim = out_dim
+    if in_dim:
+      if out_dim:
+        assert in_dim == out_dim
+      assert isinstance(in_dim, Dim)
+      axis = self.input_data.get_axis_from_description(in_dim)
+    else:
+      axis = self.input_data.feature_dim_axis
+    dim = self.input_data.batch_shape[axis]
+    assert dim is not None, "%s: in_dim %s must be static in input %s" % (self, in_dim or axis, self.input_data)
     with self.var_creation_scope():
       scale = self.add_param(tf_compat.v1.get_variable("scale", [dim], initializer=tf.ones_initializer()))
       bias = self.add_param(tf_compat.v1.get_variable("bias", [dim], initializer=tf.zeros_initializer()))
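For context, a minimal usage sketch of the new in_dim option from a network config. This is an assumption about the usual RETURNN layer-dict style, not part of the diff above; the dim helper FeatureDim from returnn.tf.util.data and the names feat_dim, "ln" and "data" are illustrative placeholders.

from returnn.tf.util.data import FeatureDim

feat_dim = FeatureDim("feature", dimension=512)  # static dim, as required by the assert in the diff

network = {
  "ln": {
    "class": "layer_norm",
    "from": "data",
    "in_dim": feat_dim,  # new option: normalize over this Dim instead of the default feature_dim_axis
  },
  "output": {"class": "copy", "from": "ln"},
}

Without in_dim (or out_dim), the layer keeps its previous behavior and normalizes over the input's feature_dim_axis.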