Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,16 +817,21 @@ class NormLayer(_ConcatInputLayer):
"""
layer_class = "norm"

def __init__(self, axes, param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
def __init__(self, axis=NotSpecified, axes=NotSpecified,
param_shape=NotSpecified, scale=True, bias=True, epsilon=1e-6, **kwargs):
"""
:param Dim|str|list[Dim|str] axes: axes over which the mean and variance are computed, e.g. "F" or "TF"
:param Dim|str|list[Dim|str] axis: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
:param Dim|str|list[Dim|str] axes: axis or axes over which the mean and variance are computed, e.g. "F" or "TF"
:param Dim|str|list[Dim|str]|tuple[Dim|str] param_shape: shape of the scale and bias parameters.
You can also refer to (static) axes of the input, such as the feature-dim.
This is also the default, i.e. a param-shape of [F], independent of the axes to normalize over.
:param bool scale: add trainable scale parameters
:param bool bias: add trainable bias parameters
:param float epsilon: epsilon for numerical stability
"""
if axis is not NotSpecified:
assert axes is NotSpecified
axes = axis
super(NormLayer, self).__init__(**kwargs)
assert not self.input_data.sparse
x = self.input_data.placeholder
Expand Down