
Commit b31b0a0

LayerNormLayer, handle in_dim, out_dim

Fix #834.

1 parent 498d0a9

1 file changed: +14 -3 lines

returnn/tf/layers/basic.py
@@ -754,15 +754,26 @@ class LayerNormLayer(_ConcatInputLayer):
   """
   layer_class = "layer_norm"
 
-  def __init__(self, epsilon=1e-6, **kwargs):
+  def __init__(self, in_dim=None, out_dim=None, epsilon=1e-6, **kwargs):
     """
+    :param Dim|None in_dim:
+    :param Dim|None out_dim:
     :param float epsilon:
     """
     super(LayerNormLayer, self).__init__(**kwargs)
     assert not self.input_data.sparse
     x = self.input_data.placeholder
-    dim = self.input_data.dim
-    axis = self.input_data.feature_dim_axis
+    if not in_dim and out_dim:
+      in_dim = out_dim
+    if in_dim:
+      if out_dim:
+        assert in_dim == out_dim
+      assert isinstance(in_dim, Dim)
+      axis = self.input_data.get_axis_from_description(in_dim)
+    else:
+      axis = self.input_data.feature_dim_axis
+    dim = self.input_data.batch_shape[axis]
+    assert dim is not None, "%s: in_dim %s must be static in input %s" % (self, in_dim or axis, self.input_data)
     with self.var_creation_scope():
       scale = self.add_param(tf_compat.v1.get_variable("scale", [dim], initializer=tf.ones_initializer()))
       bias = self.add_param(tf_compat.v1.get_variable("bias", [dim], initializer=tf.zeros_initializer()))
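
For context beyond the diff, a hypothetical usage sketch (not part of this commit): with this change, the axis to normalize over can be selected explicitly by passing a Dim as in_dim, instead of always defaulting to the feature axis; out_dim is accepted as an alias and must match in_dim if both are given. The dim tag and net-dict entries below are assumptions for illustration, following the usual RETURNN config conventions:

  # Hypothetical example, not from this commit; feat_dim and the net dict
  # are made-up names for illustration.
  from returnn.tf.util.data import Dim

  feat_dim = Dim(kind=Dim.Types.Feature, description="feature", dimension=512)

  network = {
    # Normalizes over feat_dim. Without "in_dim", the layer keeps its old
    # behavior and normalizes over input_data.feature_dim_axis.
    "norm": {"class": "layer_norm", "from": "data", "in_dim": feat_dim, "epsilon": 1e-6},
    "output": {"class": "copy", "from": "norm"},
  }

Note that the resolved dim must be static, as the new assertion enforces, because the scale and bias parameters are created with shape [dim]. The hunk ends before the normalization itself; as a reference for what the layer computes over the chosen axis, here is a minimal NumPy sketch (independent of the RETURNN implementation) showing the role of epsilon, scale and bias:

  import numpy as np

  def layer_norm_ref(x, scale, bias, axis=-1, epsilon=1e-6):
    # Normalize x to zero mean and unit variance along `axis`, then apply
    # the learned elementwise scale and bias.
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return scale * (x - mean) / np.sqrt(var + epsilon) + bias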
