Skip to content

Commit e432a28

Browse files
committed
add test and numpy/tf implementation for slice_nd
1 parent 42566fe commit e432a28

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

returnn/tf/util/basic.py

+27
Original file line numberDiff line numberDiff line change
@@ -3631,6 +3631,33 @@ def windowed_nd(source, window_size, window_left=None, window_right=None,
36313631
return final
36323632

36333633

3634+
def slice_nd2(x, start, size):
3635+
"""
3636+
This is a more generic slice function, where arbitrary many common axis between x and start are allowed.
3637+
Here we assume that x and start have their axis layed in the same order.
3638+
3639+
:param tf.Tensor x: shape (B, T1, ..., Tn, D)
3640+
:param tf.Tensor start: shape (B,T1 .., Tn-1), int32 which automatically indicates n as the slice-axis
3641+
:param int axis: in the range [0..n-1]
3642+
:param int|tf.Tensor size: scalar
3643+
:return: ret[b, t1, .., tn-1, 0..size, :] = x[b, t1, .., tn-1, start[B, t1, .., tn-1]+0..size, :]
3644+
In case the slices go out of bounds of the slice dimension and we will pad with zeros.
3645+
:rtype: tf.Tensor
3646+
"""
3647+
with tf.name_scope("slice_nd"):
3648+
shape = x.shape
3649+
len_common_dims = len(start.shape) # nr of common dims
3650+
slice_dim = shape[len_common_dims] # dim of axis to be sliced
3651+
assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."
3652+
# Create indexes for the slices where slice_idx[B,T1 .., Tn-1] = start[B,T1 .., Tn-1] + range(size)
3653+
slice_idx = tf.tile(tf.expand_dims(start, -1), [1] * len_common_dims + [size]) + tf.range(size) # (B,T1 .., Tn-1, size)
3654+
mask = tf.logical_or(tf.greater(slice_idx, slice_dim - 1), tf.less(slice_idx, 0)) # (B,T1 .., Tn-1, size)
3655+
slice_idx = tf.clip_by_value(slice_idx, 0, slice_dim - 1) # cliped slice idx
3656+
res = tf.gather(x, slice_idx, axis=len_common_dims, batch_dims=len_common_dims)
3657+
res = where_bc(mask, tf.zeros_like(res), res) # zero-padding
3658+
return res
3659+
3660+
36343661
def slice_nd(x, start, size):
36353662
"""
36363663
:param tf.Tensor x: shape (B, T, ...)

tests/test_TFUtil.py

+61-2
Original file line numberDiff line numberDiff line change
@@ -1637,6 +1637,8 @@ def test_windowed_nd_big():
16371637

16381638

16391639
def naive_slice_nd(x, start, size):
1640+
# old implementation, check out naive_slice_nd2
1641+
16401642
slices_shape = [x.shape[0], size] + list(x.shape)[2:]
16411643
ys = numpy.zeros(shape=slices_shape)
16421644
for i in range(len(start)):
@@ -1653,6 +1655,63 @@ def naive_slice_nd(x, start, size):
16531655
return ys
16541656

16551657

1658+
def naive_slice_nd2(x, start, size):
1659+
# Assuming that x: [B, T1, T2, .., Tn, D] and start: [B, T1, .., Tn-1]
1660+
# i.e. the dimensions of x and start are ordered accordingly.
1661+
# (Otherwise we should require the slice axis too.)
1662+
1663+
len_common_dims = len(start.shape)
1664+
slice_shape = (size,) + x.shape[len_common_dims+1:]
1665+
result_shape = start.shape[0:len_common_dims] + slice_shape # shape of output
1666+
result = numpy.zeros(result_shape)
1667+
1668+
slice_axis_dim = x.shape[len_common_dims] # dim of axis being sliced
1669+
for index, start_position in numpy.ndenumerate(start):
1670+
end_position = min(start_position+size, slice_axis_dim) # padding required
1671+
1672+
# no padding
1673+
padding = ((0,0),)
1674+
for i in range(1, len(slice_shape)):
1675+
padding += ((0, 0),)
1676+
1677+
# if required replace the first padding tuple, which corresponds to the slice axis
1678+
if end_position < start_position+size:
1679+
padding = ((0,size - end_position + start_position),) + padding[1:]
1680+
result[index] = numpy.pad(x[index][start_position:end_position], padding, mode='constant', constant_values=0)
1681+
return result
1682+
1683+
1684+
def test_slice_nd_multi_dim():
1685+
n_batch = 2
1686+
n_time_1 = 2
1687+
n_time_2 = 3 # slice axis
1688+
n_dim = 2
1689+
size = 2
1690+
source = numpy.arange(24, dtype=numpy.float32).reshape(n_batch, n_time_1, n_time_2, n_dim).astype("float32")
1691+
start = numpy.array([[0,1],[1,2]]).astype("int32")
1692+
naive = naive_slice_nd2(source, start, size)
1693+
source_tf = tf.constant(source)
1694+
real = slice_nd2(source_tf, start=start, size=size).eval()
1695+
print("source:")
1696+
print(source)
1697+
print("naive:")
1698+
print(naive)
1699+
print("real:")
1700+
print(real)
1701+
expected_output = numpy.array(
1702+
[[[[0, 1],
1703+
[2, 3]],
1704+
[[8, 9],
1705+
[10, 11]]],
1706+
1707+
[[[14, 15],
1708+
[16, 17]],
1709+
[[22, 23],
1710+
[0, 0]]]]) # padding
1711+
numpy.testing.assert_almost_equal(naive, expected_output)
1712+
numpy.testing.assert_almost_equal(real, expected_output)
1713+
1714+
16561715
def test_slice_nd_small():
16571716
n_batch = 3
16581717
n_time = 4
@@ -1662,7 +1721,7 @@ def test_slice_nd_small():
16621721
source = numpy.arange(1, n_batch*n_time*n_dim + 1, dtype=numpy.float32).reshape(n_batch, n_time, n_dim).astype("float32")
16631722
source_tf = tf.constant(source)
16641723
naive = naive_slice_nd(source, start, size)
1665-
real = slice_nd(source_tf, start=start, size=size).eval()
1724+
real = slice_nd2(source_tf, start=start, size=size).eval()
16661725
print("source:")
16671726
print(source)
16681727
print("naive:")
@@ -1682,7 +1741,7 @@ def test_slice_nd_big():
16821741
source = numpy.random.random((n_batch, n_time, n_dim)).astype("float32")
16831742
source_tf = tf.constant(source)
16841743
naive = naive_slice_nd(source, start, size)
1685-
real = slice_nd(source_tf, start=start, size=size).eval()
1744+
real = slice_nd2(source_tf, start=start, size=size).eval()
16861745
print("source:")
16871746
print(source)
16881747
print("naive:")

0 commit comments

Comments
 (0)