Skip to content

Commit 96fedc4

Browse files
committed
Dim auto_generated flag (#950)
Allows for a better `Dim.is_equal` that does not rely on the description. See #634.
1 parent d5803d2 commit 96fedc4

File tree

5 files changed

+91
-71
lines changed

5 files changed

+91
-71
lines changed

returnn/tf/layers/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ def _base_get_out_data_from_opts(cls, network, name,
399399
feature_dim_tag = out_dim
400400
else:
401401
dim = out_type.get("dim", None)
402-
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim)
402+
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim, auto_generated=True)
403403
if feature_dim_axis in (NotSpecified, None):
404404
if sources_data.feature_dim_axis is None:
405405
feature_dim_axis = len(dim_tags)

returnn/tf/layers/basic.py

+47-39
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,7 @@ def get_out_data_from_opts(
10561056
if out_dim:
10571057
assert out_dim.dimension == new_dim
10581058
else:
1059-
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim)
1059+
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True)
10601060
return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name)
10611061

10621062

@@ -1236,7 +1236,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T
12361236
out_spatial_dim = Dim(
12371237
kind=Dim.Types.Spatial,
12381238
description="sliced-time:%s" % name,
1239-
dimension=size)
1239+
dimension=size, auto_generated=True)
12401240
gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
12411241
out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim)
12421242
position = InternalLayer(
@@ -2544,7 +2544,9 @@ def get_out_data_from_opts(cls, name, shape, dtype="float32", **kwargs):
25442544
:param str dtype:
25452545
:rtype: Data
25462546
"""
2547-
dim_tags = [d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d) for i, d in enumerate(shape)]
2547+
dim_tags = [
2548+
d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d, auto_generated=True)
2549+
for i, d in enumerate(shape)]
25482550
return Data(name="%s_output" % name, dim_tags=dim_tags, dtype=dtype)
25492551

25502552

@@ -2630,7 +2632,7 @@ def get_out_data_from_opts(cls, name, network, shape, maxval, minval=0, dtype="i
26302632
elif isinstance(d, int):
26312633
d = Dim(
26322634
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
2633-
description="%s:static:%i" % (name, i),
2635+
description="%s:static:%i" % (name, i), auto_generated=True,
26342636
dimension=d)
26352637
else:
26362638
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
@@ -2696,12 +2698,12 @@ def get_out_data_from_opts(cls, name, limit, start=0, delta=1, dtype=None, spars
26962698
else:
26972699
dtype = "int32"
26982700
dim = len(range(start, limit, delta))
2699-
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name)
2701+
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name, auto_generated=True)
27002702
if out_spatial_dim:
27012703
tag.declare_same_as(out_spatial_dim)
27022704
sparse_dim = None
27032705
if sparse:
2704-
sparse_dim = SpatialDim("%s:range-indices" % name)
2706+
sparse_dim = SpatialDim("%s:range-indices" % name, auto_generated=True)
27052707
return Data(name="%s_output" % name, dim_tags=[tag], dtype=dtype, sparse_dim=sparse_dim)
27062708

27072709

@@ -2816,7 +2818,7 @@ def get_out_data_from_opts(cls, name, sources, dtype="int32", sparse=False, out_
28162818
dim_tag = Dim.get_tag_from_size_tensor(source.placeholder)
28172819
if not dim_tag:
28182820
dim_tag = Dim(
2819-
kind=Dim.Types.Spatial, description="%s_input_len" % name,
2821+
kind=Dim.Types.Spatial, description="%s_input_len" % name, auto_generated=True,
28202822
batch=source.batch, control_flow_ctx=source.control_flow_ctx,
28212823
dyn_size_ext=source)
28222824
if source.placeholder is not None:
@@ -2957,7 +2959,7 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None,
29572959
if out_dim:
29582960
assert out_dim.dimension == dim
29592961
else:
2960-
out_dim = FeatureDim("%s:gating" % name, dimension=dim)
2962+
out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True)
29612963
if n_out is not NotSpecified:
29622964
assert n_out == dim
29632965
return Data(
@@ -3096,15 +3098,15 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window
30963098
filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
30973099
out_spatial_dim = Dim(
30983100
kind=Dim.Types.Spatial, description="%s:spatial" % name,
3099-
dimension=dim, derived_from_tag=in_spatial_dim,
3101+
dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True,
31003102
batch=data.batch, control_flow_ctx=data.control_flow_ctx)
31013103
data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
31023104
new_dim_axis = axis + 1 # add new axis right after
31033105
if window_dim:
31043106
assert window_dim.dimension == window_size
31053107
else:
31063108
window_dim = Dim(
3107-
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size)
3109+
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True)
31083110
return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True)
31093111

31103112
# noinspection PyMethodOverriding
@@ -3602,9 +3604,11 @@ def _get_axis_size_splits_num_splits(cls, name, input_data, axis=None,
36023604
err_prefix, out_dims, dim, input_data)
36033605
if not out_dims:
36043606
assert size_splits
3605-
out_dims = [Dim(
3606-
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
3607-
dimension=size_splits[idx]) for idx in range(len(size_splits))]
3607+
out_dims = [
3608+
Dim(
3609+
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
3610+
dimension=size_splits[idx], auto_generated=True)
3611+
for idx in range(len(size_splits))]
36083612
return axis, out_dims
36093613

36103614
def _make_split_layer(self, idx):
@@ -3884,6 +3888,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
38843888
kind=axis_dim_tag.kind,
38853889
description="%s_split_dims%i_rem" % (name, rem_dim_idx),
38863890
dimension=resolved_shape_dims[rem_dim_idx],
3891+
auto_generated=True,
38873892
derived_from_tag=axis_dim_tag,
38883893
batch=axis_dim_tag.batch, control_flow_ctx=axis_dim_tag.control_flow_ctx)
38893894
if rem_dim.dimension is None and axis_dim_tag.dyn_size_ext is not None:
@@ -3900,7 +3905,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
39003905
Dim(
39013906
kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial,
39023907
description="%s_split_dims%i" % (name, i),
3903-
dimension=shape_dim)
3908+
dimension=shape_dim, auto_generated=True)
39043909
if rem_dim is None or i != rem_dim_idx else rem_dim
39053910
for i, shape_dim in enumerate(resolved_shape_dims))
39063911
out_batch = data.batch
@@ -4175,7 +4180,7 @@ def get_out_data_from_opts(cls, name, sources, num_axes, in_dim="T", out_dims=No
41754180
assert not declare_same_sizes_as
41764181
else:
41774182
out_dims = [
4178-
SpatialDim("%s:unflatten-nd:%i" % (name, i))
4183+
SpatialDim("%s:unflatten-nd:%i" % (name, i), auto_generated=True)
41794184
for i in range(num_axes)]
41804185
if declare_same_sizes_as:
41814186
for i, other in declare_same_sizes_as.items():
@@ -4255,7 +4260,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
42554260
else:
42564261
new_dim = Dim(
42574262
kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
4258-
description="%s_expand_dims" % name,
4263+
description="%s_expand_dims" % name, auto_generated=True,
42594264
dimension=dim)
42604265
data = data.copy_template(name="%s_output" % name)
42614266
data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
@@ -4411,7 +4416,7 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None,
44114416
if isinstance(repetitions, int):
44124417
out_dim = tag * repetitions
44134418
else:
4414-
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag)
4419+
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True)
44154420
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)
44164421

44174422

@@ -4821,12 +4826,12 @@ def map_axis_name(s):
48214826
pass
48224827
else:
48234828
out.sparse_dim = Dim(
4824-
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name)
4829+
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name, auto_generated=True)
48254830
if increase_sparse_dim:
48264831
assert out.sparse
48274832
out.sparse_dim = Dim(
48284833
kind=Dim.Types.Feature, dimension=out.sparse_dim.dimension + 1,
4829-
description="%s:inc-sparse-dim" % name)
4834+
description="%s:inc-sparse-dim" % name, auto_generated=True)
48304835
if batch_dim_base:
48314836
out.batch = batch_dim_base.output.batch
48324837
return out
@@ -5103,7 +5108,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None,
51035108
cls._check_defined_in_spatial_dims(len(in_spatial_dims) == 1)
51045109
if input_expand_dims:
51055110
for i in range(input_expand_dims):
5106-
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1)
5111+
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1, auto_generated=True)
51075112
input_data = input_data.copy_add_dim_by_tag(dim_tag, unbroadcast=True)
51085113
in_spatial_dims.append(dim_tag)
51095114
if input_split_feature_dim:
@@ -5280,10 +5285,10 @@ def get_out_data_from_opts(
52805285
filter_size=filter_size[i], stride=strides[i], dilation_rate=dilation_rate[i], padding=padding)
52815286
dim_tags.append(Dim(
52825287
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
5283-
derived_from_tag=old_tag, undefined=not old_tag))
5288+
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
52845289
if not out_dim:
52855290
assert n_out
5286-
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
5291+
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
52875292
dim_tags.append(out_dim)
52885293
feature_dim_axis = NotSpecified
52895294
# Swap the dims if the input dim order doesn't fit the flag auto_use_channel_first.
@@ -5784,10 +5789,10 @@ def get_out_data_from_opts(cls, name, sources, network,
57845789
padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2
57855790
dim_tags.append(Dim(
57865791
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
5787-
derived_from_tag=old_tag, undefined=not old_tag))
5792+
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
57885793
if not out_dim:
57895794
assert n_out
5790-
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
5795+
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
57915796
dim_tags.append(out_dim)
57925797
return Data(
57935798
name="%s_output" % name, dim_tags=dim_tags,
@@ -6000,7 +6005,8 @@ def get_out_data_from_opts(cls, name, sources, mode="", axes=None, axis=None, ke
60006005
out_time_dim_axis = x.time_dim_axis
60016006
if keep_dims:
60026007
for i in axes:
6003-
y_dim_tags[i] = Dim(kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i))
6008+
y_dim_tags[i] = Dim(
6009+
kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i), auto_generated=True)
60046010
else:
60056011
if out_batch_dim_axis in axes:
60066012
out_batch_dim_axis = None
@@ -6201,7 +6207,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
62016207
out = common_source.copy_template(name="%s_output" % name)
62026208
if not out_spatial_dim:
62036209
out_spatial_dim = Dim(
6204-
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources))
6210+
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True)
62056211
assert out_spatial_dim.dimension == len(sources)
62066212
out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)
62076213
return out
@@ -6333,7 +6339,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding=None, size=None, ke
63336339
dim_tags = list(data.dim_tags)
63346340
for i, a in enumerate(axes):
63356341
dim_tags[a] = Dim(
6336-
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i])
6342+
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i],
6343+
auto_generated=True)
63376344
data = data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True)
63386345
else:
63396346
assert all([d == 1 for d in res_dims])
@@ -6484,7 +6491,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base
64846491
assert not out_dim
64856492
out_dim = size_base.output.get_time_dim_tag()
64866493
if not out_dim:
6487-
out_dim = (repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name)) + in_dim
6494+
out_dim = (
6495+
repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim
64886496
assert out_dim.dimension == out_dim_int
64896497
x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim)
64906498
if isinstance(repeat, LayerBase):
@@ -6647,7 +6655,7 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
66476655
data = data.copy_move_axis(old_axis=axis, new_axis=0)
66486656
data = data.copy_with_batch_dim_axis(1)
66496657
if not out_dim:
6650-
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name)
6658+
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name, auto_generated=True)
66516659
data = data.copy_template_replace_dim_tag(axis=0, new_dim_tag=out_dim)
66526660
data.time_dim_axis = 0
66536661
return data
@@ -7143,7 +7151,7 @@ def find_axis(a_axis, b_axis):
71437151

71447152
if not b_var_dims and add_var2_if_empty:
71457153
b_var_dims.append(
7146-
SpatialDim("%s:dot:dummy-var2" % name, dimension=1))
7154+
SpatialDim("%s:dot:dummy-var2" % name, dimension=1, auto_generated=True))
71477155

71487156
dim_tags = list(a_rem_dims + a_var_dims + b_var_dims)
71497157
return Data(
@@ -7206,9 +7214,9 @@ def __init__(self, axis, amount, pad=True, adjust_size_info=True, **kwargs):
72067214
self.output.size_placeholder[axis_wob] + size_delta, 0, tf.shape(shifted)[axis])
72077215
from ..util.data import Dim
72087216
Dim(
7209-
kind=Dim.Types.Spatial, description="shift_axis",
7217+
kind=Dim.Types.Spatial, description="%s_shift_axis" % self.name,
72107218
dyn_size=new_size, batch=self.output.batch,
7211-
src_data=self.output, src_axis=axis)
7219+
src_data=self.output, src_axis=axis, auto_generated=True)
72127220
self.output.size_placeholder[axis_wob] = new_size
72137221

72147222
@classmethod
@@ -7227,7 +7235,7 @@ def get_out_data_from_opts(cls, name, amount, axis, pad, sources=(), **kwargs):
72277235
axis = out.get_axis_from_description(axis)
72287236
tag = out.dim_tags[axis]
72297237
dim = None if tag.dimension is None else max(0, tag.dimension - abs(amount))
7230-
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim)
7238+
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim, auto_generated=True)
72317239
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=tag)
72327240

72337241

@@ -7336,7 +7344,7 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa
73367344
if out_dim:
73377345
assert out_dim.dimension == dim
73387346
else:
7339-
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim)
7347+
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True)
73407348
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)
73417349

73427350

@@ -7427,7 +7435,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
74277435
axis = out.get_axis_from_description(axis, allow_int=False)
74287436
in_dim = out.dim_tags[axis]
74297437
if not out_dim:
7430-
out_dim = Dim(kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim)
7438+
out_dim = Dim(
7439+
kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim, auto_generated=True)
74317440
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)
74327441

74337442

@@ -8599,18 +8608,17 @@ def get_out_data_from_opts(cls, name, network,
85998608
elif isinstance(d, int):
86008609
d = Dim(
86018610
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
8602-
description="%s:static:%i" % (name, i),
8611+
description="%s:static:%i" % (name, i), auto_generated=True,
86038612
dimension=d)
86048613
else:
86058614
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
86068615
dim_tags.append(d)
86078616
if add_time_axis:
86088617
dim_tags.insert(
8609-
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1))
8618+
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True))
86108619
if add_batch_axis:
86118620
dim_tags.insert(
8612-
0, Dim(
8613-
kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
8621+
0, Dim(kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
86148622
return Data(
86158623
name="%s_output" % name, dim_tags=dim_tags, dtype=dtype,
86168624
batch=network.get_global_batch_info() if add_batch_axis else None)

0 commit comments

Comments
 (0)