Skip to content

Commit 8591808

Browse files
committed
DimensionTag get_for_batch size identity fix control flow ctx
1 parent f529f98 commit 8591808

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

returnn/tf/util/data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,11 @@ def get_for_batch(self, batch):
162162
# when there are different beams with same beam size!
163163
# This breaks the current logic in get_tag_from_size_tensor.
164164
# As a workaround, we make an explicit new tensor here.
165-
from .basic import get_valid_scope_name_from_str
166-
dyn_size_ext.placeholder = tf.identity(
167-
dyn_size_ext.placeholder,
168-
name=get_valid_scope_name_from_str("%s_size_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
165+
from .basic import get_valid_scope_name_from_str, same_control_flow_ctx
166+
with same_control_flow_ctx(dyn_size_ext.placeholder):
167+
dyn_size_ext.placeholder = tf.identity(
168+
dyn_size_ext.placeholder,
169+
name=get_valid_scope_name_from_str("%s_identity_for_beam_%s" % (dyn_size_ext.name, batch.beam.name)))
169170
dyn_size_ext.placeholder._RETURNN_dyn_size_beam = batch.beam
170171
dyn_size_ext.placeholder._RETURNN_beam_expanded_base_data = beam_expanded_base_data
171172
else:

0 commit comments

Comments
 (0)