Skip to content

Commit 4749b27

Browse files
committed
consider dynamic seq_lens in slice_nd
1 parent f66ed78 commit 4749b27

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

returnn/tf/layers/basic.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
875875
:param int|None min_size: if size is None, but we want to have a min-size, set this
876876
"""
877877
super(SliceNdLayer, self).__init__(**kwargs)
878-
from returnn.tf.util.basic import slice_nd, DimensionTag
878+
from returnn.tf.util.basic import slice_nd, where_bc, expand_multiple_dims, DimensionTag
879879
assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
880880
self.start = start
881881

@@ -906,6 +906,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
906906
self.size = size
907907
slices = slice_nd(x.placeholder, start=tf.cast(start, tf.int32), size=size) # (B,size, ...)
908908

909+
if seq_lens is not None:
910+
mask = tf.greater_equal(
911+
tf.range(size)[None, :] + tf.expand_dims(start, axis=-1), seq_lens[:, None]) # (B,T1,..,Tn)
912+
mask = expand_multiple_dims(mask, list(range(slice_axis + 2, x.batch_ndim))) # (B,T1,..,Tn,1,..)
913+
slices = where_bc(mask, tf.zeros_like(slices), slices)
914+
909915
self.output.size_placeholder = x.size_placeholder.copy()
910916
if isinstance(size, tf.Tensor):
911917
self.output.size_placeholder[slice_axis] = size

0 commit comments

Comments
 (0)