Skip to content

Commit 283090d

Browse files
committed
get_extern_data global batch cleanup
1 parent 1d17547 commit 283090d

File tree

1 file changed

+12
-21
lines changed

1 file changed

+12
-21
lines changed

nn/base.py

+12-21
Original file line numberDiff line numberDiff line change
@@ -864,27 +864,18 @@ def get_extern_data(data: Data) -> Tensor:
864864
scope.extern_data[data.name] = data
865865
else:
866866
assert scope.extern_data[data.name] is data
867-
for tag in data.dim_tags:
868-
# noinspection PyProtectedMember
869-
tag._validate_in_current_graph()
870-
if tag.is_batch_dim():
871-
data.batch = tag.batch
872-
if data.have_batch_axis():
873-
if not scope.global_batch:
874-
if data.batch:
875-
scope.global_batch = data.batch
876-
elif nn.is_debug_eager_mode_enabled():
877-
scope.global_batch = nn.BatchInfo.make_global_batch_info(
878-
tf.constant(3, name="global_batch")) # https://xkcd.com/221/, but prime
879-
else:
880-
# We need some global batch info, and this needs a tensor (e.g. placeholder),
881-
# but we don't have any tensor yet, nor do we want to create any tensors at this point.
882-
# So we pass the dummy value -1.
883-
# Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
884-
# and it will be replaced with the real global batch.
885-
scope.global_batch = nn.BatchInfo.make_global_batch_info(-1)
886-
if not data.batch:
887-
data.batch = scope.global_batch
867+
if not scope.global_batch:
868+
if nn.is_debug_eager_mode_enabled():
869+
scope.global_batch = nn.BatchInfo.make_global_batch_info(
870+
tf.constant(3, name="global_batch")) # https://xkcd.com/221/, but prime
871+
else:
872+
# We need some global batch info, and this needs a tensor (e.g. placeholder),
873+
# but we don't have any tensor yet, nor do we want to create any tensors at this point.
874+
# So we pass the dummy value -1.
875+
# Such dummy global batch info with -1 will be handled specially in RETURNN init_batch_info,
876+
# and it will be replaced with the real global batch.
877+
scope.global_batch = nn.BatchInfo.make_global_batch_info(-1)
878+
data.batch = scope.global_batch
888879
root_layer_name = f"data:{data.name}"
889880
out = _get_raw_layer_by_name(root_layer_name, scope=scope, data=data)
890881
for tag in data.dim_tags:

0 commit comments

Comments
 (0)