@@ -864,18 +864,7 @@ def get_extern_data(data: Data) -> Tensor:
864
864
scope .extern_data [data .name ] = data
865
865
else :
866
866
assert scope .extern_data [data .name ] is data
867
- if not scope .global_batch :
868
- if nn .is_debug_eager_mode_enabled ():
869
- scope .global_batch = nn .BatchInfo .make_global_batch_info (
870
- tf .constant (3 , name = "global_batch" )) # https://xkcd.com/221/, but prime
871
- else :
872
- # We need some global batch info, and this needs a tensor (e.g. placeholder),
873
- # but we don't have any tensor yet, nor do we want to create any tensors at this point.
874
- # So we pass the dummy value -1.
875
- # Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
876
- # and it will be replaced with the real global batch.
877
- scope .global_batch = nn .BatchInfo .make_global_batch_info (- 1 )
878
- data .batch = scope .global_batch
867
+ data .batch = _init_global_batch ()
879
868
root_layer_name = f"data:{ data .name } "
880
869
out = _get_raw_layer_by_name (root_layer_name , scope = scope , data = data )
881
870
for tag in data .dim_tags :
@@ -958,6 +947,23 @@ class ReturnnConstructTemplateException(Exception):
958
947
"""
959
948
960
949
950
def _init_global_batch() -> nn.BatchInfo:
    """
    Get the global batch info of the root name ctx, creating it on first use.

    :return: the ``global_batch`` stored on the root of the current top :class:`nn.NameCtx`
    """
    root = nn.NameCtx.top().root
    if not root.global_batch:
        if nn.is_debug_eager_mode_enabled():
            # Eager debugging needs an actual tensor value here.
            batch_size = tf.constant(3, name="global_batch")  # https://xkcd.com/221/, but prime
        else:
            # We need some global batch info, and this needs a tensor (e.g. placeholder),
            # but we don't have any tensor yet, nor do we want to create any tensors at this point.
            # So we pass the dummy value -1.
            # Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
            # and it will be replaced with the real global batch.
            batch_size = -1
        root.global_batch = nn.BatchInfo.make_global_batch_info(batch_size)
    return root.global_batch
965
+
966
+
961
967
def _data_from_layer_dict (layer_dict : LayerDictRaw , * , tensor : Tensor ) -> Data :
962
968
"""
963
969
Use RETURNN layer_class.get_out_data_from_opts to get the :class:`Data`.
@@ -976,6 +982,7 @@ def _data_from_layer_dict(layer_dict: LayerDictRaw, *, tensor: Tensor) -> Data:
976
982
train_flag = True , # should not have an effect usually for templates, except maybe in debug-eager-mode
977
983
inside_rec_time_dim = loop .axis if loop else None ,
978
984
control_flow_ctx = nn .NameCtx .inner_control_flow ())
985
+ net .extern_data .set_batch_info (_init_global_batch ())
979
986
980
987
ref_to_layer_name = {} # type: Dict[nn.NameCtx, str]
981
988
0 commit comments