@@ -817,10 +817,10 @@ class NormLayer(_ConcatInputLayer):
  """
  layer_class = "norm"

-  def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, **kwargs):
+  def __init__(self, axes, param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
    """
-    :param str|list[str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
-    :param str|list[str]|tuple[str]|int|list[int]|tuple[int] param_shape: shape of the scale and bias parameters.
+    :param Dim|str|list[Dim|str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
+    :param Dim|str|list[Dim|str]|tuple[Dim|str] param_shape: shape of the scale and bias parameters.
      You can also refer to (static) axes of the input, such as the feature-dim.
      This is also the default, i.e. a param-shape of [F], independent of the axes to normalize over.
    :param bool scale: add trainable scale parameters
@@ -830,11 +830,20 @@ def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, *
    super(NormLayer, self).__init__(**kwargs)
    assert not self.input_data.sparse
    x = self.input_data.placeholder
-    assert isinstance(param_shape, str)  # not implemented otherwise yet
-    param_axes = sorted(self.input_data.get_axes_from_description(param_shape))
-    param_shape = [self.input_data.batch_shape[axis] for axis in param_axes]
-    assert all(isinstance(dim, int) for dim in param_shape), "%s: only static param shape allowed" % self
-    param_bc_shape = [dim if axis in param_axes else 1 for (axis, dim) in enumerate(self.input_data.batch_shape)]
+    if scale or bias:
+      if param_shape is NotSpecified:
+        param_shape = "F"
+      if isinstance(param_shape, (list, tuple)):
+        param_axes = [self.input_data.get_axis_from_description(a, allow_int=False) for a in param_shape]
+      else:
+        param_axes = [self.input_data.get_axis_from_description(param_shape, allow_int=False)]
+      assert sorted(set(param_axes)) == sorted(param_axes), "%s: param_shape %r should be unique" % (self, param_shape)
+      param_shape = [self.input_data.batch_shape[axis] for axis in param_axes]
+      assert all(isinstance(dim, int) for dim in param_shape), "%s: only static param shape allowed" % self
+      param_dim_tags = [self.input_data.dim_tags[axis] for axis in param_axes]
+    else:
+      assert param_shape is NotSpecified or not param_shape
+      param_dim_tags = None
    axes = self.input_data.get_axes_from_description(axes)

    mean = tf.reduce_mean(x, axis=axes, keepdims=True, name="mean")
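Note on the change above: the removed param_bc_shape code broadcast the scale/bias parameters against the normalized tensor by manually inserting singleton dims via tf.reshape, while the new code wraps each parameter in a Data object with the matching dim tags and lets copy_compatible_to handle the broadcast. A minimal standalone TensorFlow sketch of what the old manual broadcasting amounted to, assuming a batch-major [B, T, F] input and a per-feature parameter; the concrete shapes and the variance computation are illustrative and not taken from this diff:

# Standalone TensorFlow sketch (illustrative only, not part of this commit):
# the broadcasting that the removed param_bc_shape / tf.reshape code did by hand.
import tensorflow as tf

x = tf.random.normal([2, 5, 8])                      # [batch, time, feature]
mean = tf.reduce_mean(x, axis=-1, keepdims=True)
variance = tf.reduce_mean(tf.square(x - mean), axis=-1, keepdims=True)
norm_x = (x - mean) * tf.math.rsqrt(variance + 1e-6)

scale_param = tf.ones([8])                           # param_shape == [F]
param_bc_shape = [1, 1, 8]                           # keep the param axis, 1 elsewhere
norm_x *= tf.reshape(scale_param, param_bc_shape)    # old manual broadcast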
@@ -844,11 +853,15 @@ def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, *
    if scale:
      with self.var_creation_scope():
        scale_param = self.add_param(tf_compat.v1.get_variable("scale", param_shape, initializer=tf.ones_initializer()))
-        norm_x *= tf.reshape(scale_param, param_bc_shape)
+        norm_x *= (
+          Data(name="scale_param", dim_tags=param_dim_tags, placeholder=scale_param)
+          .copy_compatible_to(self.output).placeholder)
    if bias:
      with self.var_creation_scope():
        bias_param = self.add_param(tf_compat.v1.get_variable("bias", param_shape, initializer=tf.zeros_initializer()))
-        norm_x += tf.reshape(bias_param, param_bc_shape)
+        norm_x += (
+          Data(name="bias_param", dim_tags=param_dim_tags, placeholder=bias_param)
+          .copy_compatible_to(self.output).placeholder)
    self.output.placeholder = norm_x
    self.output.size_placeholder = self.input_data.size_placeholder.copy()
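For reference, a hypothetical usage sketch of the layer touched by this commit in a RETURNN network config; the dict-style network definition and the "from" key follow the usual RETURNN layer conventions and are assumptions, not part of this diff:

# Hypothetical usage sketch (assumed RETURNN config conventions, not from this commit):
# layer-norm-style normalization over the feature axis with trainable scale and bias.
network = {
  "norm": {
    "class": "norm",   # NormLayer.layer_class
    "from": "data",    # assumed input layer name
    "axes": "F",       # normalize over the feature axis
    # "param_shape" now defaults to NotSpecified and falls back to "F"
    # when scale or bias parameters are created (see the diff above).
    "scale": True,
    "bias": True,
    "epsilon": 1e-6,
  },
}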