Skip to content

Commit 1a6e233

Browse files
authored
MathNormLayer, axis option (#840)
Fix #833.
1 parent 85f7420 commit 1a6e233

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

returnn/tf/layers/basic.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -868,26 +868,34 @@ class MathNormLayer(_ConcatInputLayer):
868868
"""
869869
layer_class = "math_norm"
870870

871-
def __init__(self, p, axes, keep_dims=False, **kwargs):
871+
def __init__(self, p, axis=NotSpecified, axes=NotSpecified, keep_dims=False, **kwargs):
872872
"""
873873
:param int|float p:
874-
:param str|list[str] axes:
874+
:param Dim|str|list[Dim|str] axis:
875+
:param Dim|str|list[Dim|str] axes:
875876
:param bool keep_dims:
876877
"""
878+
if axis is not NotSpecified:
879+
assert axes is NotSpecified
880+
axes = axis
877881
super(MathNormLayer, self).__init__(**kwargs)
878882
x = self.input_data.copy()
879883
x.placeholder = tf.abs(x.placeholder) ** p
880884
self.output.placeholder = ReduceLayer.reduce(x, mode="sum", axes=axes, keep_dims=keep_dims) ** (1. / p)
881885

882886
@classmethod
883-
def get_out_data_from_opts(cls, name, sources, axes, keep_dims=False, **kwargs):
887+
def get_out_data_from_opts(cls, name, sources, axis=NotSpecified, axes=NotSpecified, keep_dims=False, **kwargs):
884888
"""
885889
:param str name:
886890
:param list[LayerBase] sources:
887-
:param str|list[str] axes:
891+
:param Dim|str|list[Dim|str] axis:
892+
:param Dim|str|list[Dim|str] axes:
888893
:param bool keep_dims:
889894
:rtype: Data
890895
"""
896+
if axis is not NotSpecified:
897+
assert axes is NotSpecified
898+
axes = axis
891899
return ReduceLayer.get_out_data_from_opts(name=name, sources=sources, axes=axes, keep_dims=keep_dims)
892900

893901

0 commit comments

Comments
 (0)