Skip to content

Commit 8717216

Browse files
committed
RecLayer subnet output check format
1 parent 363f335 commit 8717216

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

returnn/tf/layers/rec.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2587,8 +2587,9 @@ def cond(i, net_vars, acc_tas, seq_len_info=None):
25872587
if output_layer:
25882588
assert isinstance(output_layer, LayerBase)
25892589
output_data = output_layer.output.copy_as_time_major()
2590-
assert 0 in output_data.size_placeholder
2591-
rec_layer.output.size_placeholder = output_data.size_placeholder.copy()
2590+
# No need to copy size_placeholder, as we have dim tags now.
2591+
# However, we should check that the format is right.
2592+
assert rec_layer.output.dim_tags == output_data.dim_tags
25922593
output = output_data.placeholder
25932594
else:
25942595
assert seq_len is not None

0 commit comments

Comments
 (0)