@@ -875,7 +875,7 @@ def __init__(self, start, size, min_size=None, **kwargs):
875
875
:param int|None min_size: if size is None, but we want to have a min-size, set this
876
876
"""
877
877
super (SliceNdLayer , self ).__init__ (** kwargs )
878
- from returnn .tf .util .basic import slice_nd , DimensionTag
878
+ from returnn .tf .util .basic import slice_nd , where_bc , expand_multiple_dims , DimensionTag
879
879
assert start .output .have_batch_axis () and self .input_data .have_batch_axis ()
880
880
self .start = start
881
881
@@ -906,6 +906,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
906
906
self .size = size
907
907
slices = slice_nd (x .placeholder , start = tf .cast (start , tf .int32 ), size = size ) # (B,size, ...)
908
908
909
+ if seq_lens is not None :
910
+ mask = tf .greater_equal (
911
+ tf .range (size )[None , :] + tf .expand_dims (start , axis = - 1 ), seq_lens [:, None ]) # (B,T1,..,Tn)
912
+ mask = expand_multiple_dims (mask , list (range (slice_axis + 2 , x .batch_ndim ))) # (B,T1,..,Tn,1,..)
913
+ slices = where_bc (mask , tf .zeros_like (slices ), slices )
914
+
909
915
self .output .size_placeholder = x .size_placeholder .copy ()
910
916
if isinstance (size , tf .Tensor ):
911
917
self .output .size_placeholder [slice_axis ] = size
0 commit comments