Skip to content

Commit 61a7f70

Browse files
committed
DotLayer: mask dynamic reduce axes when needed
Fixes #629
1 parent a4d85ee commit 61a7f70

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

returnn/tf/layers/basic.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -5375,7 +5375,7 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
53755375
:param bool add_var2_if_empty: if var2=None, add dim=1 at the end
53765376
:param bool debug: will print debug shapes, etc.
53775377
"""
5378-
from returnn.tf.util.basic import prod, get_shape
5378+
from returnn.tf.util.basic import prod, get_shape, get_padding_info_dict_ref, mask_dyn_seq_len_nd
53795379
super(DotLayer, self).__init__(**kwargs)
53805380
a_out = self.sources[0].output.copy()
53815381
b_out = self.sources[1].output.copy()
@@ -5427,6 +5427,29 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
54275427
for (d1, d2, i1, i2) in zip(a_reduce_dims, b_reduce_dims, a_reduce_axes, b_reduce_axes)])
54285428
a_var_dim = prod(a_var_dims)
54295429
b_var_dim = prod(b_var_dims)
5430+
a_reduce_dyn_axes = [i for i in a_reduce_axes if a_out.batch_shape[i] is None]
5431+
b_reduce_dyn_axes = [i for i in b_reduce_axes if b_out.batch_shape[i] is None]
5432+
assert len(a_reduce_dyn_axes) == len(b_reduce_dyn_axes)
5433+
if a_reduce_dyn_axes:
5434+
a_pad, b_pad = get_padding_info_dict_ref(a), get_padding_info_dict_ref(b)
5435+
a_pad_values = [a_pad.get(a_out.dim_tags[i], None) for i in a_reduce_dyn_axes]
5436+
b_pad_values = [b_pad.get(b_out.dim_tags[i], None) for i in b_reduce_dyn_axes]
5437+
if set(a_pad_values) == {0}:
5438+
self._info_reduce_mask = "source-0-already-masked" # it's already masked as needed
5439+
elif set(b_pad_values) == {0}:
5440+
self._info_reduce_mask = "source-1-already-masked" # it's already masked as needed
5441+
else:
5442+
# We need to apply a mask.
5443+
# We don't need it on both a and b. We can either apply it on a or on b.
5444+
# Use some very simple heuristic where the mask is maybe cheaper.
5445+
if len(a_shape) < len(b_shape):
5446+
a = mask_dyn_seq_len_nd(a_out, pad_value=0, axes=a_reduce_dyn_axes)
5447+
self._info_reduce_mask = "mask-source-0"
5448+
else:
5449+
b = mask_dyn_seq_len_nd(b_out, pad_value=0, axes=b_reduce_dyn_axes)
5450+
self._info_reduce_mask = "mask-source-1"
5451+
else:
5452+
self._info_reduce_mask = "none-dynamic"
54305453
a_reduce_dim = prod(a_reduce_dims)
54315454
b_reduce_dim = prod(b_reduce_dims)
54325455
if debug:

0 commit comments

Comments
 (0)