Skip to content

Commit 1292e80

Browse files
committed
update comments
1 parent 1cb1ae4 commit 1292e80

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

returnn/tf/layers/basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,8 @@ def __init__(self, start, size, min_size=None, **kwargs):
908908

909909
if seq_lens is not None:
910910
mask = tf.greater_equal(
911-
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,Tn)
912-
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,1,1,..,Tn,1)
911+
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,T1,..,Tn)
912+
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,T1,..,Tn,1,..)
913913
slices = where_bc(mask, tf.zeros_like(slices), slices)
914914

915915
self.output.size_placeholder = x.size_placeholder.copy()

0 commit comments

Comments
 (0)