@@ -1637,6 +1637,8 @@ def test_windowed_nd_big():
1637
1637
1638
1638
1639
1639
def naive_slice_nd (x , start , size ):
1640
+ # old implementation, check out naive_slice_nd2
1641
+
1640
1642
slices_shape = [x .shape [0 ], size ] + list (x .shape )[2 :]
1641
1643
ys = numpy .zeros (shape = slices_shape )
1642
1644
for i in range (len (start )):
@@ -1653,6 +1655,63 @@ def naive_slice_nd(x, start, size):
1653
1655
return ys
1654
1656
1655
1657
1658
+ def naive_slice_nd2 (x , start , size ):
1659
+ # Assuming that x: [B, T1, T2, .., Tn, D] and start: [B, T1, .., Tn-1]
1660
+ # i.e. the dimensions of x and start are ordered accordingly.
1661
+ # (Otherwise we should require the slice axis too.)
1662
+
1663
+ len_common_dims = len (start .shape )
1664
+ slice_shape = (size ,) + x .shape [len_common_dims + 1 :]
1665
+ result_shape = start .shape [0 :len_common_dims ] + slice_shape # shape of output
1666
+ result = numpy .zeros (result_shape )
1667
+
1668
+ slice_axis_dim = x .shape [len_common_dims ] # dim of axis being sliced
1669
+ for index , start_position in numpy .ndenumerate (start ):
1670
+ end_position = min (start_position + size , slice_axis_dim ) # padding required
1671
+
1672
+ # no padding
1673
+ padding = ((0 ,0 ),)
1674
+ for i in range (1 , len (slice_shape )):
1675
+ padding += ((0 , 0 ),)
1676
+
1677
+ # if required replace the first padding tuple, which corresponds to the slice axis
1678
+ if end_position < start_position + size :
1679
+ padding = ((0 ,size - end_position + start_position ),) + padding [1 :]
1680
+ result [index ] = numpy .pad (x [index ][start_position :end_position ], padding , mode = 'constant' , constant_values = 0 )
1681
+ return result
1682
+
1683
+
1684
+ def test_slice_nd_multi_dim ():
1685
+ n_batch = 2
1686
+ n_time_1 = 2
1687
+ n_time_2 = 3 # slice axis
1688
+ n_dim = 2
1689
+ size = 2
1690
+ source = numpy .arange (24 , dtype = numpy .float32 ).reshape (n_batch , n_time_1 , n_time_2 , n_dim ).astype ("float32" )
1691
+ start = numpy .array ([[0 ,1 ],[1 ,2 ]]).astype ("int32" )
1692
+ naive = naive_slice_nd2 (source , start , size )
1693
+ source_tf = tf .constant (source )
1694
+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
1695
+ print ("source:" )
1696
+ print (source )
1697
+ print ("naive:" )
1698
+ print (naive )
1699
+ print ("real:" )
1700
+ print (real )
1701
+ expected_output = numpy .array (
1702
+ [[[[0 , 1 ],
1703
+ [2 , 3 ]],
1704
+ [[8 , 9 ],
1705
+ [10 , 11 ]]],
1706
+
1707
+ [[[14 , 15 ],
1708
+ [16 , 17 ]],
1709
+ [[22 , 23 ],
1710
+ [0 , 0 ]]]]) # padding
1711
+ numpy .testing .assert_almost_equal (naive , expected_output )
1712
+ numpy .testing .assert_almost_equal (real , expected_output )
1713
+
1714
+
1656
1715
def test_slice_nd_small ():
1657
1716
n_batch = 3
1658
1717
n_time = 4
@@ -1662,7 +1721,7 @@ def test_slice_nd_small():
1662
1721
source = numpy .arange (1 , n_batch * n_time * n_dim + 1 , dtype = numpy .float32 ).reshape (n_batch , n_time , n_dim ).astype ("float32" )
1663
1722
source_tf = tf .constant (source )
1664
1723
naive = naive_slice_nd (source , start , size )
1665
- real = slice_nd (source_tf , start = start , size = size ).eval ()
1724
+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
1666
1725
print ("source:" )
1667
1726
print (source )
1668
1727
print ("naive:" )
@@ -1682,7 +1741,7 @@ def test_slice_nd_big():
1682
1741
source = numpy .random .random ((n_batch , n_time , n_dim )).astype ("float32" )
1683
1742
source_tf = tf .constant (source )
1684
1743
naive = naive_slice_nd (source , start , size )
1685
- real = slice_nd (source_tf , start = start , size = size ).eval ()
1744
+ real = slice_nd2 (source_tf , start = start , size = size ).eval ()
1686
1745
print ("source:" )
1687
1746
print (source )
1688
1747
print ("naive:" )
0 commit comments