@@ -5375,7 +5375,7 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
5375
5375
:param bool add_var2_if_empty: if var2=None, add dim=1 at the end
5376
5376
:param bool debug: will print debug shapes, etc.
5377
5377
"""
5378
- from returnn .tf .util .basic import prod , get_shape
5378
+ from returnn .tf .util .basic import prod , get_shape , get_padding_info_dict_ref , mask_dyn_seq_len_nd
5379
5379
super (DotLayer , self ).__init__ (** kwargs )
5380
5380
a_out = self .sources [0 ].output .copy ()
5381
5381
b_out = self .sources [1 ].output .copy ()
@@ -5427,6 +5427,29 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
5427
5427
for (d1 , d2 , i1 , i2 ) in zip (a_reduce_dims , b_reduce_dims , a_reduce_axes , b_reduce_axes )])
5428
5428
a_var_dim = prod (a_var_dims )
5429
5429
b_var_dim = prod (b_var_dims )
5430
+ a_reduce_dyn_axes = [i for i in a_reduce_axes if a_out .batch_shape [i ] is None ]
5431
+ b_reduce_dyn_axes = [i for i in b_reduce_axes if b_out .batch_shape [i ] is None ]
5432
+ assert len (a_reduce_dyn_axes ) == len (b_reduce_dyn_axes )
5433
+ if a_reduce_dyn_axes :
5434
+ a_pad , b_pad = get_padding_info_dict_ref (a ), get_padding_info_dict_ref (b )
5435
+ a_pad_values = [a_pad .get (a_out .dim_tags [i ], None ) for i in a_reduce_dyn_axes ]
5436
+ b_pad_values = [b_pad .get (b_out .dim_tags [i ], None ) for i in b_reduce_dyn_axes ]
5437
+ if set (a_pad_values ) == {0 }:
5438
+ self ._info_reduce_mask = "source-0-already-masked" # it's already masked as needed
5439
+ elif set (b_pad_values ) == {0 }:
5440
+ self ._info_reduce_mask = "source-1-already-masked" # it's already masked as needed
5441
+ else :
5442
+ # We need to apply a mask.
5443
+ # We don't need it on both a and b. We can either apply it on a or on b.
5444
+ # Use some very simple heuristic where the mask is maybe cheaper.
5445
+ if len (a_shape ) < len (b_shape ):
5446
+ a = mask_dyn_seq_len_nd (a_out , pad_value = 0 , axes = a_reduce_dyn_axes )
5447
+ self ._info_reduce_mask = "mask-source-0"
5448
+ else :
5449
+ b = mask_dyn_seq_len_nd (b_out , pad_value = 0 , axes = b_reduce_dyn_axes )
5450
+ self ._info_reduce_mask = "mask-source-1"
5451
+ else :
5452
+ self ._info_reduce_mask = "none-dynamic"
5430
5453
a_reduce_dim = prod (a_reduce_dims )
5431
5454
b_reduce_dim = prod (b_reduce_dims )
5432
5455
if debug :
0 commit comments