Skip to content

Commit 0c8e669

Browse files
committed
correct mask for seq_lens
1 parent 9ca5a6d commit 0c8e669

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

returnn/tf/layers/basic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -909,9 +909,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
909909
if seq_lens is not None:
910910
mask = tf.greater_equal(
911911
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,Tn)
912-
mask = expand_multiple_dims(
913-
mask,
914-
list(range(1, slice_axis + 1)) + list(range(slice_axis + 2, x.batch_ndim))) # (B,1,1,..,Tn,1)
912+
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,1,1,..,Tn,1)
915913
slices = where_bc(mask, tf.zeros_like(slices), slices)
916914

917915
self.output.size_placeholder = x.size_placeholder.copy()

0 commit comments

Comments
 (0)