Skip to content

Commit 5822bee

Browse files
committed
test_out_shape
Test for #706
1 parent 86c846f commit 5822bee

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

tests/test_TFNetworkLayer.py

+39
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,45 @@ def test_SplitDimsLayer_split_feature():
16201620
numpy.testing.assert_almost_equal(out_v, in_v.reshape(out_v.shape))
16211621

16221622

1623+
def test_out_shape():
1624+
# https://github.com/rwth-i6/returnn/issues/706
1625+
# Note: Using SplitDimsLayer would also be nice to test out_shape. Or any layer which creates a new dim.
1626+
# However, for that, we need https://github.com/rwth-i6/returnn/issues/597 first.
1627+
from returnn.tf.util.data import BatchDim, VerifyOutShapeException
1628+
time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
1629+
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="feature", dimension=10)
1630+
config = Config({
1631+
"extern_data": {
1632+
"data": {"dim_tags": [BatchDim, time_dim, feat_dim]} # [B,T,D]
1633+
}
1634+
})
1635+
with make_scope() as session:
1636+
net = TFNetwork(config=config)
1637+
net.construct_from_dict({
1638+
"output": {
1639+
'class': 'softmax_over_spatial', 'from': 'data',
1640+
"out_shape": {BatchDim, time_dim, feat_dim}
1641+
}
1642+
})
1643+
out = net.get_default_output_layer().output
1644+
session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))
1645+
with make_scope():
1646+
other_feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="other-feature", dimension=10)
1647+
net = TFNetwork(config=config)
1648+
# noinspection PyBroadException
1649+
try:
1650+
net.construct_from_dict({
1651+
"output": {
1652+
'class': 'softmax_over_spatial', 'from': 'data',
1653+
"out_shape": {BatchDim, time_dim, other_feat_dim}
1654+
}
1655+
})
1656+
except VerifyOutShapeException as exc:
1657+
print("Got expected exception: %r" % exc)
1658+
else:
1659+
raise Exception("Expected an exception but did not get any.")
1660+
1661+
16231662
def _check_MergeDimsLayer(session, in_data_opts, in_static_shape, opts, out_data_shape, out_static_shape,
16241663
in_sizes=None, out_sizes=None):
16251664
"""

0 commit comments

Comments
 (0)