Skip to content

Commit c099773

Browse files
committed
test_Data_verify_out_shape_optional_implicit_dim
#1153
1 parent 8ae5332 commit c099773

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/test_TFUtil.py

+20
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,26 @@ def test_Data_copy_feat_with_vocab():
12231223
assert data2.vocab is vocab
12241224

12251225

1226+
def test_Data_verify_out_shape_optional_implicit_dim():
1227+
# https://github.com/rwth-i6/returnn/issues/1153
1228+
from returnn.tf.util.data import batch_dim, FeatureDim, SpatialDim, BatchInfo, VerifyOutShapeException
1229+
batch = BatchInfo.make_global_batch_info(-1)
1230+
time_dim = SpatialDim("time")
1231+
time_dim.batch = batch
1232+
time_dim.dyn_size_ext = Data("dyn_size_ext", dim_tags=[batch_dim], dtype="int32", batch=batch)
1233+
feat_dim = FeatureDim("feat", dimension=3)
1234+
x = Data("x", dim_tags=[time_dim, feat_dim])
1235+
try:
1236+
x.verify_out_shape({time_dim, feat_dim})
1237+
except VerifyOutShapeException as exc:
1238+
print("Got expected exception:", exc)
1239+
assert "Missing dims" in str(exc)
1240+
else:
1241+
raise Exception("did not get expected exception")
1242+
# This should not raise an exception:
1243+
x.verify_out_shape({time_dim, feat_dim}, allow_missing_implicit_dims=True)
1244+
1245+
12261246
def test_Dim_copy():
12271247
# https://github.com/rwth-i6/returnn/issues/860
12281248
import copy

0 commit comments

Comments
 (0)