Commit 0189da5

Data same_dim_tags_as fix auto_create_placeholders (#1159)
Don't first create a new size placeholder and then later call declare_same_as on it. This is especially required once declare_same_as becomes stricter (#1143).

Also fix wrong batch info: the dim tag could carry old, invalid batch info, e.g. the global batch_dim left over from an earlier run. If we really needed that info, we would have to validate the dim tag first; but it is better to just remove it and set it up cleanly. The engine now resets the global batch dim when initializing the network.
1 parent a6236c5 commit 0189da5
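
For illustration, this is the usage pattern the commit fixes, distilled from the new test added in tests/test_TFUtil.py below (a minimal sketch; the dim value and names are arbitrary):

from returnn.tf.util.data import Data, SpatialDim

# Two inputs declare their time axis to be the same user-provided tag.
time_tag = SpatialDim("time")
# With auto_create_placeholders=True, the seq-length placeholder for the time
# axis must be created only after same_dim_tags_as is applied; previously a
# throwaway placeholder was created first and then merged away via declare_same_as.
data = Data("data", dim=3, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
classes = Data("classes", dim=3, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
assert data.get_sequence_lengths() is classes.get_sequence_lengths()  # one shared placeholder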

File tree: 4 files changed, +40 -26 lines

returnn/tf/engine.py (+2)

@@ -1293,6 +1293,8 @@ def _init_network(self, net_desc, epoch=None):
     use_dataset_pipeline = False
     if self.config.is_true("dataset_pipeline"):
       use_dataset_pipeline = True
+    from returnn.tf.util.data import batch_dim
+    batch_dim.batch = None  # make sure it is reset
     extern_data = ExternData()
     extern_data.init_from_config(config=self.config, auto_create_placeholders=not use_dataset_pipeline)
     if use_dataset_pipeline:
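
The reset addresses the stale batch info mentioned in the commit message: batch_dim in returnn.tf.util.data is a process-wide singleton, so after an earlier run its .batch attribute can still point at batch info tied to a graph that no longer exists. A hedged sketch of the effect:

from returnn.tf.util.data import batch_dim

# batch_dim is global and survives across Engine runs in the same process.
# Clearing its batch info before ExternData is built forces the new network
# to attach fresh batch info instead of reusing the stale one.
batch_dim.batch = None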

returnn/tf/util/data.py (+8, -14)

@@ -2823,11 +2823,6 @@ def __init__(self, name,
       if time_dim_axis is NotSpecified:
         time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
       dim_tags = tuple(dim_tags)
-      if auto_create_placeholders:
-        _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags)
-      if batch_dim_axis_ is not None:
-        if dim_tags[batch_dim_axis_].batch and not self._batch:
-          self._batch = dim_tags[batch_dim_axis_].batch
       del shape_
       del batch_dim_axis_
     else:
@@ -2846,7 +2841,7 @@ def __init__(self, name,
       dim_tags = _infer_dim_tags_tuple_from_shape(
         shape, batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis,
         size_placeholder=size_placeholder, name=name,
-        auto_create_placeholders=auto_create_placeholders,
+        extern_data=auto_create_placeholders,
         dim_tags=dim_tags, sparse=sparse)
       del batch_dim_axis
       del shape
@@ -2893,8 +2888,10 @@ def __init__(self, name,
         base_tag = self._dim_tags[_axis]
         if base_tag != _dim_tag:
           base_tag.declare_same_as(_dim_tag)
-          if _dim_tag.dyn_size is not None:
-            self.set_dynamic_size(_axis, _dim_tag.dyn_size)
+        self._dim_tags = self._dim_tags[:_axis] + (_dim_tag,) + self._dim_tags[_axis + 1:]
+    if auto_create_placeholders:
+      # Do that after same_dim_tags_as handling.
+      _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=self._dim_tags)
     self._adapt_batch_consistent_dim_tags()
     self.sanity_check(assume_complete=False)
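
In effect, Data.__init__ now substitutes the user-provided tag into self._dim_tags while handling same_dim_tags_as, and creates size placeholders only afterwards, directly on the final tags. A simplified sketch of the new order (hypothetical helper names; the real code uses the private _auto_create_size_placeholders_on_dim_tags):

def apply_same_dim_tags_as(dim_tags, same_dim_tags_as, auto_create_placeholders, create_size_placeholders):
  """
  :param list[Dim] dim_tags: tags inferred from the shape; no placeholders created yet
  :param dict[int,Dim] same_dim_tags_as: axis -> user-provided tag
  :param bool auto_create_placeholders:
  :param callable create_size_placeholders: stand-in for the private helper
  :rtype: tuple[Dim]
  """
  for axis, user_tag in sorted(same_dim_tags_as.items()):
    if dim_tags[axis] != user_tag:
      dim_tags[axis].declare_same_as(user_tag)
    dim_tags[axis] = user_tag  # keep the user tag, not the auto-created one
  if auto_create_placeholders:
    # Only now, after same_dim_tags_as handling, so placeholders attach to the final tags.
    create_size_placeholders(dim_tags)
  return tuple(dim_tags)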

@@ -5780,7 +5777,7 @@ def _infer_dim_tags_tuple_from_shape(
   size_placeholder,
   dim_tags,
   name,
-  auto_create_placeholders
+  extern_data
 ):
   """
   :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis
@@ -5790,7 +5787,7 @@ def _infer_dim_tags_tuple_from_shape(
   :param bool sparse:
   :param dict[int,tf.Tensor]|None size_placeholder: key is axis without batch-dim
   :param dict[int,Dim]|None dim_tags: some existing explicitly specified dim tags. key is axis with batch-dim
-  :param bool auto_create_placeholders:
+  :param bool extern_data:
   :param str name:
   :return: dim tags tuple
   :rtype: tuple[Dim]
@@ -5808,8 +5805,6 @@ def _infer_dim_tags_tuple_from_shape(
   dim_tags = dim_tags.copy() if dim_tags else {}
   if batch_dim_axis is not None and batch_dim_axis not in dim_tags:
     dim_tags[batch_dim_axis] = Dim(kind=Dim.Types.Batch, description="batch:%s" % name)
-  # noinspection PyShadowingNames
-  batch_dim = dim_tags[batch_dim_axis] if batch_dim_axis is not None else None
   # Note: Consistent to Data.get_dim_tag,
   # prefer interpretation as spatial axis if there is a dynamic size or this is marked as time axis.
   if size_placeholder:
@@ -5833,7 +5828,7 @@ def _infer_dim_tags_tuple_from_shape(
     axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis)
     dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None
     dim = batch_shape[axis]
-    if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis:
+    if extern_data and dim is None and dyn_size is None and axis != batch_dim_axis:
       if not tag:
         if axis == time_dim_axis:
           tag_name = "time"
@@ -5845,7 +5840,6 @@ def _infer_dim_tags_tuple_from_shape(
           # This is such that Dim.is_equal behaves as before, e.g. in Data.get_common_data.
           kind=Dim.Types.Spatial)
         dim_tags[axis] = tag
-        _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag, batch_dim=batch_dim)
       dyn_size = tag.dyn_size
     if tag:
       # Just some sanity checks.

tests/test_TFNetworkLayer.py (+14, -12)

@@ -2223,12 +2223,13 @@ def test_SplitDimsLayer_dim_tags():
   feat_dim = FeatureDim("feat", 3)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
-      'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
+        'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
+    })


 def test_SplitDimsLayer_dim_tags_expand():
@@ -2238,12 +2239,13 @@ def test_SplitDimsLayer_dim_tags_expand():
   expand_dim = SpatialDim("expand_dim", 1)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
-      'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
+        'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
+    })


 def test_SplitDimsLayer_dim_tags_split_batch_simple():
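
make_scope is a helper defined in the test files; wrapping each test in it gives the test its own TF graph and session, so graph-level state such as auto-created placeholders or the global batch dim cannot leak between tests (which matters now that the engine resets batch_dim.batch). Roughly, assuming it matches the usual definition in the test suite:

import contextlib
import tensorflow as tf

@contextlib.contextmanager
def make_scope():
  """
  Yield a session inside a fresh graph, so each test starts with clean graph-level state.
  :rtype: tf.compat.v1.Session
  """
  with tf.Graph().as_default() as graph:
    with tf.compat.v1.Session(graph=graph) as session:
      yield session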

tests/test_TFUtil.py (+16)

@@ -1243,6 +1243,22 @@ def test_Data_verify_out_shape_optional_implicit_dim():
   x.verify_out_shape({time_dim, feat_dim}, allow_missing_implicit_dims=True)


+def test_Data_auto_create_placeholders_same_dim_tags_as_existing():
+  # Came up via: https://github.com/rwth-i6/returnn/pull/1143
+  n_out = 3
+  time_tag = SpatialDim("time")
+  with tf.Graph().as_default() as graph, tf_compat.v1.Session(graph=graph) as session:
+    assert isinstance(graph, tf.Graph)
+    data = Data("data", dim=n_out, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    classes = Data("classes", dim=n_out, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    assert time_tag.dyn_size is not None  # this is not so relevant and might change
+    seq_len = time_tag.dyn_size
+    assert seq_len is data.get_sequence_lengths() is classes.get_sequence_lengths()
+    assert seq_len.op.type == "Placeholder"
+    placeholder_ops = [op for op in graph.get_operations() if op.type == "Placeholder"]
+    assert_equal(set(placeholder_ops), {data.placeholder.op, classes.placeholder.op, time_tag.dyn_size.op})
+
+
 def test_Dim_copy():
   # https://github.com/rwth-i6/returnn/issues/860
   import copy
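
A note on the new test's final assertion (my reading, not stated in the commit): the whole graph must contain exactly three Placeholder ops, the two data placeholders plus one shared seq-length placeholder, whereas the pre-fix ordering would presumably have left an extra orphaned size placeholder behind. Continuing inside the test's session scope, an equivalent check would be:

# Equivalent check (continuation of the test above; graph, data, classes,
# time_tag as defined there):
placeholder_ops = [op for op in graph.get_operations() if op.type == "Placeholder"]
assert len(placeholder_ops) == 3  # data, classes, one shared seq-length placeholder
assert time_tag.dyn_size is data.get_sequence_lengths() is classes.get_sequence_lengths()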
