Commit 498d0a9

NormLayer, param_shape support dim tags (#841)
Fix #832.
Parent: 1a6e233

1 file changed: +23 −10

returnn/tf/layers/basic.py

+23 −10

@@ -817,10 +817,10 @@ class NormLayer(_ConcatInputLayer):
   """
   layer_class = "norm"
 
-  def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, **kwargs):
+  def __init__(self, axes, param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
     """
-    :param str|list[str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
-    :param str|list[str]|tuple[str]|int|list[int]|tuple[int] param_shape: shape of the scale and bias parameters.
+    :param Dim|str|list[Dim|str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
+    :param Dim|str|list[Dim|str]|tuple[Dim|str] param_shape: shape of the scale and bias parameters.
       You can also refer to (static) axes of the input, such as the feature-dim.
       This is also the default, i.e. a param-shape of [F], independent of the axes to normalize over.
     :param bool scale: add trainable scale parameters
@@ -830,11 +830,20 @@ def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, *
     super(NormLayer, self).__init__(**kwargs)
     assert not self.input_data.sparse
     x = self.input_data.placeholder
-    assert isinstance(param_shape, str)  # not implemented otherwise yet
-    param_axes = sorted(self.input_data.get_axes_from_description(param_shape))
-    param_shape = [self.input_data.batch_shape[axis] for axis in param_axes]
-    assert all(isinstance(dim, int) for dim in param_shape), "%s: only static param shape allowed" % self
-    param_bc_shape = [dim if axis in param_axes else 1 for (axis, dim) in enumerate(self.input_data.batch_shape)]
+    if scale or bias:
+      if param_shape is NotSpecified:
+        param_shape = "F"
+      if isinstance(param_shape, (list, tuple)):
+        param_axes = [self.input_data.get_axis_from_description(a, allow_int=False) for a in param_shape]
+      else:
+        param_axes = [self.input_data.get_axis_from_description(param_shape, allow_int=False)]
+      assert sorted(set(param_axes)) == sorted(param_axes), "%s: param_shape %r should be unique" % (self, param_shape)
+      param_shape = [self.input_data.batch_shape[axis] for axis in param_axes]
+      assert all(isinstance(dim, int) for dim in param_shape), "%s: only static param shape allowed" % self
+      param_dim_tags = [self.input_data.dim_tags[axis] for axis in param_axes]
+    else:
+      assert param_shape is NotSpecified or not param_shape
+      param_dim_tags = None
     axes = self.input_data.get_axes_from_description(axes)
 
     mean = tf.reduce_mean(x, axis=axes, keepdims=True, name="mean")
@@ -844,11 +853,15 @@ def __init__(self, axes, param_shape="F", scale=True, bias=True, epsilon=1e-6, *
     if scale:
       with self.var_creation_scope():
         scale_param = self.add_param(tf_compat.v1.get_variable("scale", param_shape, initializer=tf.ones_initializer()))
-      norm_x *= tf.reshape(scale_param, param_bc_shape)
+      norm_x *= (
+        Data(name="scale_param", dim_tags=param_dim_tags, placeholder=scale_param)
+        .copy_compatible_to(self.output).placeholder)
     if bias:
       with self.var_creation_scope():
         bias_param = self.add_param(tf_compat.v1.get_variable("bias", param_shape, initializer=tf.zeros_initializer()))
-      norm_x += tf.reshape(bias_param, param_bc_shape)
+      norm_x += (
+        Data(name="bias_param", dim_tags=param_dim_tags, placeholder=bias_param)
+        .copy_compatible_to(self.output).placeholder)
     self.output.placeholder = norm_x
     self.output.size_placeholder = self.input_data.size_placeholder.copy()
 
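With this change, the "norm" layer accepts Dim objects (dim tags) for axes and param_shape, in addition to the axis-description strings it supported before. A minimal config sketch of how this could be used; the extern_data setup, dim-tag definitions, and layer names below are illustrative assumptions, not part of the commit:

# Hypothetical RETURNN config sketch: layer norm over the feature axis,
# with the scale/bias parameter shape given by an explicit dim tag.
from returnn.tf.util.data import Dim, batch_dim

time_dim = Dim(kind=Dim.Types.Spatial, description="time", dimension=None)       # dynamic time axis
feature_dim = Dim(kind=Dim.Types.Feature, description="feature", dimension=512)  # static feature axis

extern_data = {
  "data": {"dim_tags": [batch_dim, time_dim, feature_dim]},
}

network = {
  # Mean/variance over the feature axis; scale and bias each get shape [feature_dim] (512 here).
  "ln": {"class": "norm", "from": "data", "axes": feature_dim, "param_shape": feature_dim},
  "output": {"class": "copy", "from": "ln"},
}

Passing param_shape as an axis string such as "F" still works and remains the default (NotSpecified falls back to "F"); with scale=False and bias=False, param_shape must be left unspecified or empty.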