Skip to content

Commit c3c0838

Browse files
committed
NormLayer, axis option
Fix #831.
1 parent 498d0a9 commit c3c0838

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

returnn/tf/layers/basic.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,16 +817,21 @@ class NormLayer(_ConcatInputLayer):
817817
"""
818818
layer_class = "norm"
819819

820-
def __init__(self, axes, param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
820+
def __init__(self, axis=NotSpecified, axes=NotSpecified,
821+
param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
821822
"""
822-
:param Dim|str|list[Dim|str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
823+
:param Dim|str|list[Dim|str] axis: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
824+
:param Dim|str|list[Dim|str] axes: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
823825
:param Dim|str|list[Dim|str]|tuple[Dim|str] param_shape: shape of the scale and bias parameters.
824826
You can also refer to (static) axes of the input, such as the feature-dim.
825827
This is also the default, i.e. a param-shape of [F], independent of the axes to normalize over.
826828
:param bool scale: add trainable scale parameters
827829
:param bool bias: add trainable bias parameters
828830
:param float epsilon: epsilon for numerical stability
829831
"""
832+
if axis is not NotSpecified:
833+
assert axes is NotSpecified
834+
axes = axis
830835
super(NormLayer, self).__init__(**kwargs)
831836
assert not self.input_data.sparse
832837
x = self.input_data.placeholder

0 commit comments

Comments
 (0)