Skip to content

Commit a45b4a5

Browse files
committed
add test and numpy/tf implementation for slice_nd
1 parent 42566fe commit a45b4a5

File tree

3 files changed

+115
-4
lines changed

3 files changed

+115
-4
lines changed

returnn/tf/util/basic.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3631,6 +3631,33 @@ def windowed_nd(source, window_size, window_left=None, window_right=None,
36313631
return final
36323632

36333633

3634+
def slice_nd2(x, start, size):
3635+
"""
3636+
This is a more generic slice function, where arbitrary many common axis between x and start are allowed.
3637+
Here we assume that x and start have their axis layed in the same order.
3638+
3639+
:param tf.Tensor x: shape (B, T1, ..., Tn, D)
3640+
:param tf.Tensor start: shape (B,T1 .., Tn-1), int32 which automatically indicates n as the slice-axis
3641+
:param int axis: in the range [0..n-1]
3642+
:param int|tf.Tensor size: scalar
3643+
:return: ret[b, t1, .., tn-1, 0..size, :] = x[b, t1, .., tn-1, start[B, t1, .., tn-1]+0..size, :]
3644+
In case the slices go out of bounds of the slice dimension and we will pad with zeros.
3645+
:rtype: tf.Tensor
3646+
"""
3647+
with tf.name_scope("slice_nd"):
3648+
shape = x.shape
3649+
len_common_dims = len(start.shape) # nr of common dims
3650+
slice_dim = shape[len_common_dims] # dim of axis to be sliced
3651+
assert size < slice_dim, "Slice size cannot be bigger than the dimension to be sliced."
3652+
# Create indexes for the slices where slice_idx[B,T1 .., Tn-1] = start[B,T1 .., Tn-1] + range(size)
3653+
slice_idx = tf.tile(tf.expand_dims(start, -1), [1] * len_common_dims + [size]) + tf.range(size) # (B,T1 .., Tn-1, size)
3654+
mask = tf.logical_or(tf.greater(slice_idx, slice_dim - 1), tf.less(slice_idx, 0)) # (B,T1 .., Tn-1, size)
3655+
slice_idx = tf.clip_by_value(slice_idx, 0, slice_dim - 1) # cliped slice idx
3656+
res = tf.gather(x, slice_idx, axis=len_common_dims, batch_dims=len_common_dims)
3657+
res = where_bc(mask, tf.zeros_like(res), res) # zero-padding
3658+
return res
3659+
3660+
36343661
def slice_nd(x, start, size):
36353662
"""
36363663
:param tf.Tensor x: shape (B, T, ...)

tests/test_TFNetworkRecLayer.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3328,7 +3328,7 @@ def test_rec_subnet_simple_rnn():
33283328
print("rnn_cell also fine.")
33293329

33303330

3331-
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
3331+
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, from_=None, rtol=1e-4):
33323332
"""
33333333
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
33343334
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
@@ -3344,7 +3344,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33443344
subnet_layer_dict.setdefault("from", ["data:source"])
33453345
rec_layer_dict = {
33463346
"class": "rec",
3347-
"from": ["data"],
3347+
"from": ["data"] if from_ is None else [from_],
33483348
"unit": {"output": subnet_layer_dict},
33493349
"n_out": n_out,
33503350
"is_output_layer": True
@@ -3598,6 +3598,31 @@ def test_reclayer_optimize_out_access_split():
35983598
other_subnet_layers={"split": {"class": "split", "from": ["data:source"], "size_splits": [5, 8]}})
35993599

36003600

3601+
def test_reclayer_optimize_out_slice_nd():
3602+
def random_start_positions(source, **kwargs):
3603+
import tensorflow as tf
3604+
enc = source(0, as_data=True, enforce_batch_major=True, auto_convert=False)
3605+
enc_shape = tf.shape(enc.placeholder)
3606+
enc_time_dim = enc_shape[enc.time_dim_axis]
3607+
return tf.random.uniform(enc_shape[:-1], 0, enc_time_dim-2, dtype=tf.dtypes.int32)
3608+
3609+
check_reclayer_optimize_out(
3610+
{"class": "linear", "activation": None, "from": ["encoder_reduced"]},
3611+
from_="position",
3612+
other_subnet_layers={
3613+
"window": {"class": "slice_nd", "from": "base:encoder", "start": "data:source", "size": None, "min_size": 1, "is_output_layer": True},
3614+
"encoder_reduced": {"class": "reduce", "mode": "sum", "axis": "T", "from": ["base:encoder"], "is_output_layer": True}
3615+
},
3616+
shared_base_net={
3617+
"encoder": {"class": "copy", "from": "data", "is_output_layer": True},
3618+
"position": {
3619+
"class": "eval", "from": "encoder", "is_output_layer": True,
3620+
"eval": random_start_positions,
3621+
"out_type": {"batch_dim_axis": 0, "time_dim_axis": 1, "shape": (None,), "sparse": True, "dtype": "int32", "dim": None}}
3622+
}
3623+
)
3624+
3625+
36013626
def test_reclayer_att_with_kv_in_rec():
36023627
net_dict = {
36033628
'decision': {'class': 'decide', 'from': ['output'], 'loss': 'edit_distance', 'loss_opts': {}, 'target': 'classes'},

tests/test_TFUtil.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,6 +1637,8 @@ def test_windowed_nd_big():
16371637

16381638

16391639
def naive_slice_nd(x, start, size):
1640+
# old implementation, check out naive_slice_nd2
1641+
16401642
slices_shape = [x.shape[0], size] + list(x.shape)[2:]
16411643
ys = numpy.zeros(shape=slices_shape)
16421644
for i in range(len(start)):
@@ -1653,6 +1655,63 @@ def naive_slice_nd(x, start, size):
16531655
return ys
16541656

16551657

1658+
def naive_slice_nd2(x, start, size):
1659+
# Assuming that x: [B, T1, T2, .., Tn, D] and start: [B, T1, .., Tn-1]
1660+
# i.e. the dimensions of x and start are ordered accordingly.
1661+
# (Otherwise we should require the slice axis too.)
1662+
1663+
len_common_dims = len(start.shape)
1664+
slice_shape = (size,) + x.shape[len_common_dims+1:]
1665+
result_shape = start.shape[0:len_common_dims] + slice_shape # shape of output
1666+
result = numpy.zeros(result_shape)
1667+
1668+
slice_axis_dim = x.shape[len_common_dims] # dim of axis being sliced
1669+
for index, start_position in numpy.ndenumerate(start):
1670+
end_position = min(start_position+size, slice_axis_dim) # padding required
1671+
1672+
# no padding
1673+
padding = ((0,0),)
1674+
for i in range(1, len(slice_shape)):
1675+
padding += ((0, 0),)
1676+
1677+
# if required replace the first padding tuple, which corresponds to the slice axis
1678+
if end_position < start_position+size:
1679+
padding = ((0,size - end_position + start_position),) + padding[1:]
1680+
result[index] = numpy.pad(x[index][start_position:end_position], padding, mode='constant', constant_values=0)
1681+
return result
1682+
1683+
1684+
def test_slice_nd_multi_dim():
1685+
n_batch = 2
1686+
n_time_1 = 2
1687+
n_time_2 = 3 # slice axis
1688+
n_dim = 2
1689+
size = 2
1690+
source = numpy.arange(24, dtype=numpy.float32).reshape(n_batch, n_time_1, n_time_2, n_dim).astype("float32")
1691+
start = numpy.array([[0,1],[1,2]]).astype("int32")
1692+
naive = naive_slice_nd2(source, start, size)
1693+
source_tf = tf.constant(source)
1694+
real = slice_nd2(source_tf, start=start, size=size).eval()
1695+
print("source:")
1696+
print(source)
1697+
print("naive:")
1698+
print(naive)
1699+
print("real:")
1700+
print(real)
1701+
expected_output = numpy.array(
1702+
[[[[0, 1],
1703+
[2, 3]],
1704+
[[8, 9],
1705+
[10, 11]]],
1706+
1707+
[[[14, 15],
1708+
[16, 17]],
1709+
[[22, 23],
1710+
[0, 0]]]]) # padding
1711+
numpy.testing.assert_almost_equal(naive, expected_output)
1712+
numpy.testing.assert_almost_equal(real, expected_output)
1713+
1714+
16561715
def test_slice_nd_small():
16571716
n_batch = 3
16581717
n_time = 4
@@ -1662,7 +1721,7 @@ def test_slice_nd_small():
16621721
source = numpy.arange(1, n_batch*n_time*n_dim + 1, dtype=numpy.float32).reshape(n_batch, n_time, n_dim).astype("float32")
16631722
source_tf = tf.constant(source)
16641723
naive = naive_slice_nd(source, start, size)
1665-
real = slice_nd(source_tf, start=start, size=size).eval()
1724+
real = slice_nd2(source_tf, start=start, size=size).eval()
16661725
print("source:")
16671726
print(source)
16681727
print("naive:")
@@ -1682,7 +1741,7 @@ def test_slice_nd_big():
16821741
source = numpy.random.random((n_batch, n_time, n_dim)).astype("float32")
16831742
source_tf = tf.constant(source)
16841743
naive = naive_slice_nd(source, start, size)
1685-
real = slice_nd(source_tf, start=start, size=size).eval()
1744+
real = slice_nd2(source_tf, start=start, size=size).eval()
16861745
print("source:")
16871746
print(source)
16881747
print("naive:")

0 commit comments

Comments
 (0)