Skip to content

DotLayer, single reduce argument #837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 7, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6178,13 +6178,16 @@ class DotLayer(LayerBase):
"""
layer_class = "dot"

def __init__(self, red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
def __init__(self,
reduce=NotSpecified,
red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
add_var2_if_empty=NotSpecified, debug=False, **kwargs):
"""
:param str|Dim|tuple[str|DimensionTag]|list[str|DimensionTag] red1: reduce axes of first source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag] red2: reduce axes of second source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag]|None var1: var axes of first source
:param str|DimensionTag|tuple[str|DimensionTag]|list[str|DimensionTag]|None var2: var axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim] reduce: reduce axes of both sources (cannot be combined with red1/red2)
:param str|Dim|tuple[str|Dim]|list[str|Dim] red1: reduce axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim] red2: reduce axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var1: var axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var2: var axes of second source
:param bool add_var2_if_empty: if var2=None, add dim=1 at the end
:param bool debug: will print debug shapes, etc.

Expand All @@ -6196,6 +6199,9 @@ def __init__(self, red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2
from returnn.util import BehaviorVersion
from returnn.tf.util.basic import prod, get_shape, get_padding_info_dict_ref, mask_dyn_seq_len_nd
super(DotLayer, self).__init__(**kwargs)
if reduce is not NotSpecified:
assert red1 is NotSpecified and red2 is NotSpecified
red1 = red2 = reduce
BehaviorVersion.require(
condition=all(not isinstance(a, int) for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be referenced by tag or special specified, not by int.",
Expand Down Expand Up @@ -6398,33 +6404,60 @@ def _add(dims, val, d_key):
_add(dims2, var2, "var2")

@classmethod
def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-1,
def get_out_data_from_opts(cls, name, sources,
reduce=NotSpecified,
red1=NotSpecified, red2=NotSpecified, var1=NotSpecified, var2=NotSpecified,
add_var2_if_empty=NotSpecified, **kwargs):
"""
:param str name:
:param list[LayerBase] sources:
:param str|int|tuple[str|int]|list[str|int] red1: reduce axes of first source
:param str|int|tuple[str|int]|list[str|int] red2: reduce axes of second source
:param str|int|tuple[str|int]|list[str|int]|None var1: var axes of first source
:param str|int|tuple[str|int]|list[str|int]|None var2: var axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim] reduce: reduce axes of both sources (cannot be combined with red1/red2)
:param str|Dim|tuple[str|Dim]|list[str|Dim] red1: reduce axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim] red2: reduce axes of second source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var1: var axes of first source
:param str|Dim|tuple[str|Dim]|list[str|Dim]|None var2: var axes of second source
:param bool add_var2_if_empty:
:rtype: Data
"""
from returnn.util import BehaviorVersion
from ..util.data import BatchInfo
assert len(sources) == 2, "dot-layer %r: needs exactly two sources" % (name,)
# See __init__.
# As usual, do as little error checking as possible here.
if add_var2_if_empty is NotSpecified:
add_var2_if_empty = True if BehaviorVersion.get() < 3 else False
if reduce is not NotSpecified:
assert red1 is NotSpecified and red2 is NotSpecified
red1 = red2 = reduce
BehaviorVersion.require(
condition=all(not isinstance(a, int) for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be referenced by tag or special specified, not by int.",
version=3)
BehaviorVersion.require(
condition=all(a is not NotSpecified for a in (red1, red2, var1, var2)),
message="DotLayer: Axes must be specified explicitly. There is no default.",
version=3)
BehaviorVersion.require(
condition=add_var2_if_empty is NotSpecified or not add_var2_if_empty,
message="DotLayer: add_var2_if_empty not allowed",
version=3)
if BehaviorVersion.get() < 3:
# Earlier defaults: red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True.
red1 = -1 if red1 is NotSpecified else red1
red2 = -2 if red2 is NotSpecified else red2
var1 = -2 if var1 is NotSpecified else var1
var2 = -1 if var2 is NotSpecified else var2
add_var2_if_empty = True if add_var2_if_empty is NotSpecified else add_var2_if_empty
axis_desc_allow_int = True
else:
# add_var2_if_empty not supported anymore.
add_var2_if_empty = False
axis_desc_allow_int = False
a_out = sources[0].output.copy()
a_reduce_axes = a_out.get_axes_from_description(red1)
a_reduce_axes = a_out.get_axes_from_description(red1, allow_int=axis_desc_allow_int)
b_out = sources[1].output.copy()
assert not a_out.beam or not b_out.beam or a_out.beam == b_out.beam
b_reduce_axes = b_out.get_axes_from_description(red2)
b_reduce_axes = b_out.get_axes_from_description(red2, allow_int=axis_desc_allow_int)
assert a_reduce_axes and b_reduce_axes, "%s: sources %r, red1 %r, red2 %r" % (name, sources, red1, red2)
a_var_axes = a_out.get_axes_from_description(var1)
b_var_axes = b_out.get_axes_from_description(var2)
a_var_axes = a_out.get_axes_from_description(var1, allow_int=axis_desc_allow_int)
b_var_axes = b_out.get_axes_from_description(var2, allow_int=axis_desc_allow_int)
assert not set(a_reduce_axes).intersection(a_var_axes)
assert not set(b_reduce_axes).intersection(b_var_axes)
a_rem_axes = [i for i in range(a_out.batch_ndim) if i not in a_var_axes + a_reduce_axes]
Expand Down