@@ -1620,6 +1620,45 @@ def test_SplitDimsLayer_split_feature():
1620
1620
numpy .testing .assert_almost_equal (out_v , in_v .reshape (out_v .shape ))
1621
1621
1622
1622
1623
+ def test_out_shape ():
1624
+ # https://github.com/rwth-i6/returnn/issues/706
1625
+ # Note: Using SplitDimsLayer would also be nice to test out_shape. Or any layer which creates a new dim.
1626
+ # However, for that, we need https://github.com/rwth-i6/returnn/issues/597 first.
1627
+ from returnn .tf .util .data import BatchDim , VerifyOutShapeException
1628
+ time_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "time" )
1629
+ feat_dim = DimensionTag (kind = DimensionTag .Types .Feature , description = "feature" , dimension = 10 )
1630
+ config = Config ({
1631
+ "extern_data" : {
1632
+ "data" : {"dim_tags" : [BatchDim , time_dim , feat_dim ]} # [B,T,D]
1633
+ }
1634
+ })
1635
+ with make_scope () as session :
1636
+ net = TFNetwork (config = config )
1637
+ net .construct_from_dict ({
1638
+ "output" : {
1639
+ 'class' : 'softmax_over_spatial' , 'from' : 'data' ,
1640
+ "out_shape" : {BatchDim , time_dim , feat_dim }
1641
+ }
1642
+ })
1643
+ out = net .get_default_output_layer ().output
1644
+ session .run (out .placeholder , feed_dict = make_feed_dict (net .extern_data ))
1645
+ with make_scope ():
1646
+ other_feat_dim = DimensionTag (kind = DimensionTag .Types .Feature , description = "other-feature" , dimension = 10 )
1647
+ net = TFNetwork (config = config )
1648
+ # noinspection PyBroadException
1649
+ try :
1650
+ net .construct_from_dict ({
1651
+ "output" : {
1652
+ 'class' : 'softmax_over_spatial' , 'from' : 'data' ,
1653
+ "out_shape" : {BatchDim , time_dim , other_feat_dim }
1654
+ }
1655
+ })
1656
+ except VerifyOutShapeException as exc :
1657
+ print ("Got expected exception: %r" % exc )
1658
+ else :
1659
+ raise Exception ("Expected an exception but did not get any." )
1660
+
1661
+
1623
1662
def _check_MergeDimsLayer (session , in_data_opts , in_static_shape , opts , out_data_shape , out_static_shape ,
1624
1663
in_sizes = None , out_sizes = None ):
1625
1664
"""
0 commit comments