Skip to content

Commit 75a6fbc

Browse files
committed
_data_from_layer_dict fix batch in some cases
1 parent 283090d commit 75a6fbc

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

nn/base.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -864,18 +864,7 @@ def get_extern_data(data: Data) -> Tensor:
864864
scope.extern_data[data.name] = data
865865
else:
866866
assert scope.extern_data[data.name] is data
867-
if not scope.global_batch:
868-
if nn.is_debug_eager_mode_enabled():
869-
scope.global_batch = nn.BatchInfo.make_global_batch_info(
870-
tf.constant(3, name="global_batch")) # https://xkcd.com/221/, but prime
871-
else:
872-
# We need some global batch info, and this needs a tensor (e.g. placeholder),
873-
# but we don't have any tensor yet, nor do we want to create any tensors at this point.
874-
# So we pass the dummy value -1.
875-
# Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
876-
# and it will be replaced with the real global batch.
877-
scope.global_batch = nn.BatchInfo.make_global_batch_info(-1)
878-
data.batch = scope.global_batch
867+
data.batch = _init_global_batch()
879868
root_layer_name = f"data:{data.name}"
880869
out = _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
881870
for tag in data.dim_tags:
@@ -958,6 +947,23 @@ class ReturnnConstructTemplateException(Exception):
958947
"""
959948

960949

950+
def _init_global_batch() -> nn.BatchInfo:
951+
root_name_ctx = nn.NameCtx.top().root
952+
if root_name_ctx.global_batch:
953+
return root_name_ctx.global_batch
954+
if nn.is_debug_eager_mode_enabled():
955+
root_name_ctx.global_batch = nn.BatchInfo.make_global_batch_info(
956+
tf.constant(3, name="global_batch")) # https://xkcd.com/221/, but prime
957+
else:
958+
# We need some global batch info, and this needs a tensor (e.g. placeholder),
959+
# but we don't have any tensor yet, nor do we want to create any tensors at this point.
960+
# So we pass the dummy value -1.
961+
# Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
962+
# and it will be replaced with the real global batch.
963+
root_name_ctx.global_batch = nn.BatchInfo.make_global_batch_info(-1)
964+
return root_name_ctx.global_batch
965+
966+
961967
def _data_from_layer_dict(layer_dict: LayerDictRaw, *, tensor: Tensor) -> Data:
962968
"""
963969
Use RETURNN layer_class.get_out_data_from_opts to get the :class:`Data`.
@@ -976,6 +982,7 @@ def _data_from_layer_dict(layer_dict: LayerDictRaw, *, tensor: Tensor) -> Data:
976982
train_flag=True, # should not have an effect usually for templates, except maybe in debug-eager-mode
977983
inside_rec_time_dim=loop.axis if loop else None,
978984
control_flow_ctx=nn.NameCtx.inner_control_flow())
985+
net.extern_data.set_batch_info(_init_global_batch())
979986

980987
ref_to_layer_name = {} # type: Dict[nn.NameCtx, str]
981988

0 commit comments

Comments
 (0)