
Commit ff4ab0d

Dim auto_generated flag
Allows for a better Dim.is_equal which does not rely on the description. #634
1 parent 5d3d0cb
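The intended effect on Dim equality, as implemented in the returnn/tf/util/data.py changes below, can be sketched as follows (a minimal illustration with hypothetical tag names, inferred from this diff rather than taken from an actual test):

from returnn.tf.util.data import Dim

# Two auto-generated tags, e.g. created once for a template layer and again
# for the real layer during template construction, still match via
# kind/dimension/description (and hash consistently):
a = Dim(kind=Dim.Types.Spatial, description="my_layer:window", dimension=5, auto_generated=True)
b = Dim(kind=Dim.Types.Spatial, description="my_layer:window", dimension=5, auto_generated=True)
print(a == b, hash(a) == hash(b))  # expected: True True

# Explicitly user-created tags with the same description should no longer
# compare equal just because the description matches:
c = Dim(kind=Dim.Types.Spatial, description="my_layer:window", dimension=5)
d = Dim(kind=Dim.Types.Spatial, description="my_layer:window", dimension=5)
print(c == d)  # expected: False (only identity or declare_same_as makes them match)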

File tree: 4 files changed (+56, -42 lines)
  returnn/tf/layers/basic.py
  returnn/tf/layers/rec.py
  returnn/tf/layers/signal_processing.py
  returnn/tf/util/data.py


returnn/tf/layers/basic.py (+35, -30)
@@ -1039,7 +1039,7 @@ def get_out_data_from_opts(
     if out_dim:
       assert out_dim.dimension == new_dim
     else:
-      out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim)
+      out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True)
     return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name)
 
@@ -1219,7 +1219,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T
     out_spatial_dim = Dim(
       kind=Dim.Types.Spatial,
       description="sliced-time:%s" % name,
-      dimension=size)
+      dimension=size, auto_generated=True)
     gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
       out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim)
     position = InternalLayer(
 
@@ -2613,7 +2613,7 @@ def get_out_data_from_opts(cls, name, network, shape, maxval, minval=0, dtype="i
       elif isinstance(d, int):
         d = Dim(
           kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
-          description="%s:static:%i" % (name, i),
+          description="%s:static:%i" % (name, i), auto_generated=True,
           dimension=d)
       else:
         raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
 
@@ -2679,7 +2679,7 @@ def get_out_data_from_opts(cls, name, limit, start=0, delta=1, dtype=None, spars
       else:
         dtype = "int32"
     dim = len(range(start, limit, delta))
-    tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name)
+    tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name, auto_generated=True)
     if out_spatial_dim:
       tag.declare_same_as(out_spatial_dim)
     sparse_dim = None
 
@@ -2799,7 +2799,7 @@ def get_out_data_from_opts(cls, name, sources, dtype="int32", sparse=False, out_
     dim_tag = Dim.get_tag_from_size_tensor(source.placeholder)
     if not dim_tag:
       dim_tag = Dim(
-        kind=Dim.Types.Spatial, description="%s_input_len" % name,
+        kind=Dim.Types.Spatial, description="%s_input_len" % name, auto_generated=True,
         batch=source.batch, control_flow_ctx=source.control_flow_ctx,
         dyn_size_ext=source)
       if source.placeholder is not None:
 
@@ -3079,15 +3079,15 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window
         filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
     out_spatial_dim = Dim(
       kind=Dim.Types.Spatial, description="%s:spatial" % name,
-      dimension=dim, derived_from_tag=in_spatial_dim,
+      dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True,
       batch=data.batch, control_flow_ctx=data.control_flow_ctx)
     data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
     new_dim_axis = axis + 1  # add new axis right after
     if window_dim:
       assert window_dim.dimension == window_size
     else:
       window_dim = Dim(
-        kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size)
+        kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True)
     return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True)
 
   # noinspection PyMethodOverriding
@@ -3585,9 +3585,11 @@ def _get_axis_size_splits_num_splits(cls, name, input_data, axis=None,
         err_prefix, out_dims, dim, input_data)
     if not out_dims:
       assert size_splits
-      out_dims = [Dim(
-        kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
-        dimension=size_splits[idx]) for idx in range(len(size_splits))]
+      out_dims = [
+        Dim(
+          kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
+          dimension=size_splits[idx], auto_generated=True)
+        for idx in range(len(size_splits))]
     return axis, out_dims
 
   def _make_split_layer(self, idx):
@@ -3867,6 +3869,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
         kind=axis_dim_tag.kind,
         description="%s_split_dims%i_rem" % (name, rem_dim_idx),
         dimension=resolved_shape_dims[rem_dim_idx],
+        auto_generated=True,
         derived_from_tag=axis_dim_tag,
         batch=axis_dim_tag.batch, control_flow_ctx=axis_dim_tag.control_flow_ctx)
       if rem_dim.dimension is None and axis_dim_tag.dyn_size_ext is not None:
 
@@ -3883,7 +3886,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
       Dim(
         kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial,
         description="%s_split_dims%i" % (name, i),
-        dimension=shape_dim)
+        dimension=shape_dim, auto_generated=True)
       if rem_dim is None or i != rem_dim_idx else rem_dim
       for i, shape_dim in enumerate(resolved_shape_dims))
     out_batch = data.batch
 
@@ -4238,7 +4241,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
     else:
       new_dim = Dim(
         kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
-        description="%s_expand_dims" % name,
+        description="%s_expand_dims" % name, auto_generated=True,
         dimension=dim)
     data = data.copy_template(name="%s_output" % name)
     data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
 
@@ -4394,7 +4397,7 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None,
       if isinstance(repetitions, int):
         out_dim = tag * repetitions
       else:
-        out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag)
+        out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True)
     return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)
 
@@ -4804,12 +4807,12 @@ def map_axis_name(s):
         pass
       else:
         out.sparse_dim = Dim(
-          kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name)
+          kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name, auto_generated=True)
     if increase_sparse_dim:
       assert out.sparse
       out.sparse_dim = Dim(
         kind=Dim.Types.Feature, dimension=out.sparse_dim.dimension + 1,
-        description="%s:inc-sparse-dim" % name)
+        description="%s:inc-sparse-dim" % name, auto_generated=True)
     if batch_dim_base:
       out.batch = batch_dim_base.output.batch
     return out
 
@@ -5263,7 +5266,7 @@ def get_out_data_from_opts(
         filter_size=filter_size[i], stride=strides[i], dilation_rate=dilation_rate[i], padding=padding)
       dim_tags.append(Dim(
         kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
-        derived_from_tag=old_tag, undefined=not old_tag))
+        derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
     if not out_dim:
       assert n_out
       out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
 
@@ -5767,7 +5770,7 @@ def get_out_data_from_opts(cls, name, sources, network,
         padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2
       dim_tags.append(Dim(
         kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
-        derived_from_tag=old_tag, undefined=not old_tag))
+        derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
     if not out_dim:
       assert n_out
       out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
 
@@ -5983,7 +5986,8 @@ def get_out_data_from_opts(cls, name, sources, mode="", axes=None, axis=None, ke
     out_time_dim_axis = x.time_dim_axis
     if keep_dims:
       for i in axes:
-        y_dim_tags[i] = Dim(kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i))
+        y_dim_tags[i] = Dim(
+          kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i), auto_generated=True)
     else:
       if out_batch_dim_axis in axes:
         out_batch_dim_axis = None
 
@@ -6184,7 +6188,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
     out = common_source.copy_template(name="%s_output" % name)
     if not out_spatial_dim:
       out_spatial_dim = Dim(
-        kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources))
+        kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True)
     assert out_spatial_dim.dimension == len(sources)
     out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)
     return out
 
@@ -6316,7 +6320,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding=None, size=None, ke
       dim_tags = list(data.dim_tags)
       for i, a in enumerate(axes):
         dim_tags[a] = Dim(
-          kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i])
+          kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i],
+          auto_generated=True)
       data = data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True)
     else:
       assert all([d == 1 for d in res_dims])
 
@@ -6630,7 +6635,7 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
     data = data.copy_move_axis(old_axis=axis, new_axis=0)
     data = data.copy_with_batch_dim_axis(1)
     if not out_dim:
-      out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name)
+      out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name, auto_generated=True)
     data = data.copy_template_replace_dim_tag(axis=0, new_dim_tag=out_dim)
     data.time_dim_axis = 0
     return data
 
@@ -7189,9 +7194,9 @@ def __init__(self, axis, amount, pad=True, adjust_size_info=True, **kwargs):
         self.output.size_placeholder[axis_wob] + size_delta, 0, tf.shape(shifted)[axis])
       from ..util.data import Dim
       Dim(
-        kind=Dim.Types.Spatial, description="shift_axis",
+        kind=Dim.Types.Spatial, description="%s_shift_axis" % self.name,
         dyn_size=new_size, batch=self.output.batch,
-        src_data=self.output, src_axis=axis)
+        src_data=self.output, src_axis=axis, auto_generated=True)
       self.output.size_placeholder[axis_wob] = new_size
 
   @classmethod
@@ -7210,7 +7215,7 @@ def get_out_data_from_opts(cls, name, amount, axis, pad, sources=(), **kwargs):
     axis = out.get_axis_from_description(axis)
     tag = out.dim_tags[axis]
     dim = None if tag.dimension is None else max(0, tag.dimension - abs(amount))
-    tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim)
+    tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim, auto_generated=True)
     return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=tag)
 
@@ -7319,7 +7324,7 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa
     if out_dim:
       assert out_dim.dimension == dim
     else:
-      out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim)
+      out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True)
     return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)
 
@@ -7410,7 +7415,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
     axis = out.get_axis_from_description(axis, allow_int=False)
     in_dim = out.dim_tags[axis]
     if not out_dim:
-      out_dim = Dim(kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim)
+      out_dim = Dim(
+        kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim, auto_generated=True)
     return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)
 
@@ -8582,18 +8588,17 @@ def get_out_data_from_opts(cls, name, network,
       elif isinstance(d, int):
         d = Dim(
           kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
-          description="%s:static:%i" % (name, i),
+          description="%s:static:%i" % (name, i), auto_generated=True,
           dimension=d)
       else:
         raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
       dim_tags.append(d)
     if add_time_axis:
       dim_tags.insert(
-        0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1))
+        0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True))
     if add_batch_axis:
       dim_tags.insert(
-        0, Dim(
-          kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
+        0, Dim(kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
     return Data(
       name="%s_output" % name, dim_tags=dim_tags, dtype=dtype,
       batch=network.get_global_batch_info() if add_batch_axis else None)
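The layer changes above all follow the same pattern: a dim tag that the layer invents on its own is now marked auto_generated=True, while a dim tag passed in by the user (e.g. via out_dim) is used untouched. A condensed sketch of that pattern (the helper function and names are illustrative, not part of the diff):

from returnn.tf.util.data import Dim

def make_out_dim(name, in_dim, new_dim, out_dim=None):
  """Illustrative helper mirroring the get_out_data_from_opts changes above."""
  if out_dim:
    # User-provided dim tag: keep it as-is; it stays non-auto-generated.
    assert out_dim.dimension == new_dim
    return out_dim
  # Layer-created dim tag: mark it as auto-generated so that independent
  # creations (template layer vs. real layer) can still be matched by description.
  return Dim(kind=in_dim.kind, description="%s:out" % name, dimension=new_dim, auto_generated=True)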

returnn/tf/layers/rec.py (+8, -7)
@@ -3638,7 +3638,8 @@ def get_loop_acc_layer(name):
         else:
           time_dim_tag = Dim(
             kind=Dim.Types.Spatial,
-            description="dyn-time:%s/%s" % (self.parent_rec_layer.get_full_ctx_name(), search_choices))
+            description="dyn-time:%s/%s" % (self.parent_rec_layer.get_full_ctx_name(), search_choices),
+            auto_generated=True)
       elif is_out_time_dim:
         self.time_dim_tag.declare_same_as(time_dim_tag)
       output = (
 
@@ -6864,7 +6865,7 @@ def get_out_data_from_opts(cls, n_out, name, sources, **kwargs):
       else:
         time_dim = None
       time_tag = Dim(
-        kind=Dim.Types.Spatial, description="%s_self_att_time" % name, dimension=time_dim)
+        kind=Dim.Types.Spatial, description="%s_self_att_time" % name, dimension=time_dim, auto_generated=True)
       dim_tags = (batch_dim_tag, time_tag, feat_tag)
     else:
       dim_tags = (batch_dim_tag, feat_tag)
 
@@ -7242,7 +7243,7 @@ def get_out_data_from_opts(cls, name, sources,
       vocab = Vocabulary(vocab_file=vocab_file, unknown_label=vocab_unknown_label)
       tag = Dim(
         kind=Dim.Types.Feature, description="%s_ken_lm_vocab" % name,
-        dimension=vocab.num_labels, vocab=vocab)
+        dimension=vocab.num_labels, vocab=vocab, auto_generated=True)
       data = data.copy_add_dim_by_tag(tag, axis=-1, unbroadcast=True)
     return data
 
@@ -7808,7 +7809,7 @@ def _create_template(cls, name, network, sources, masked_from, unit,
       if not out_spatial_dim:
         out_spatial_dim = Dim(
           kind=Dim.Types.Spatial, description="%s:masked:time" % name,
-          derived_from_tag=source_data.get_time_dim_tag())
+          derived_from_tag=source_data.get_time_dim_tag(), auto_generated=True)
       source_data = source_data.copy_template_replace_dim_tag(
         axis=0,
         new_dim_tag=out_spatial_dim)
 
@@ -9225,17 +9226,17 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
     data = get_concat_sources_data_template(sources, name="%s_output" % name)
     # The result will be without batch dim.
     feature_dim_tag = Dim(
-      kind=Dim.Types.Feature, description="%s_rel_pos_enc_feat" % name, dimension=n_out)
+      kind=Dim.Types.Feature, description="%s_rel_pos_enc_feat" % name, dimension=n_out, auto_generated=True)
     if data.have_time_axis():
       time_dim_tag = data.get_time_dim_tag()
       # TODO using same dim tag twice will not be supported at some future point...
       data = data.copy_template_new_dim_tags((time_dim_tag, time_dim_tag, feature_dim_tag))
     else:
       # length will be ``network.get_rec_step_index() + 1``.
       dummy_dim_tag = Dim(
-        kind=Dim.Types.Spatial, description="%s_rel_pos_enc_dummy" % name, dimension=1)
+        kind=Dim.Types.Spatial, description="%s_rel_pos_enc_dummy" % name, dimension=1, auto_generated=True)
       time_dim_tag = Dim(
-        kind=Dim.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
+        kind=Dim.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None, auto_generated=True)
       data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
     return data
 

returnn/tf/layers/signal_processing.py (+2, -2)
@@ -457,9 +457,9 @@ def _compute_size_placeholder():
       new_size = nr_of_full_frames + nf_of_paded_frames
       from ..util.data import Dim
       Dim(
-        kind=Dim.Types.Spatial, description="MultiChannelMultiResolutionStft",
+        kind=Dim.Types.Spatial, description="%s:MultiChannelMultiResolutionStft" % self.name,
         dyn_size=new_size, batch=self.output.batch,
-        src_data=self.output, src_axis=self.output.get_batch_axis(0))
+        src_data=self.output, src_axis=self.output.get_batch_axis(0), auto_generated=True)
       size_placeholder_dict[0] = new_size
       return size_placeholder_dict
 

returnn/tf/util/data.py (+11, -3)
@@ -58,6 +58,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
                vocab=None,
                dyn_size=None, dyn_size_ext=None,
                undefined=False, generic=False, special=False,
+               auto_generated=False,
                match_priority=0,
                derived_from_tag=None, derived_from_op=None,
                batch=None, control_flow_ctx=None,
 
@@ -75,6 +76,9 @@ def __init__(self, kind=Types.Unspecified, description=None,
     :param bool special: Like `generic`, this can not be a dim tag of :class:`Data`.
       But this dim tag also does not match anything except itself.
       So it can be used to represent special placeholders with special meanings like ``single_step``.
+    :param bool auto_generated: This is auto-generated by RETURNN because it was not explicitly specified by the user.
+      E.g. for ConvLayer and others. This implies certain behavior on equality, such as comparing the description,
+      to allow for several independent creations of the dim tag during template construction.
     :param Dim|None derived_from_tag:
       Whether this new tag is reduced, down/up sampled, padded etc from this given other tag.
       In situations where dim tags are being matched (Data.get_common_data),
 
@@ -121,6 +125,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
     self._undefined = undefined
     self.generic = generic
     self.special = special
+    self.auto_generated = auto_generated
     # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx.
     # They each have same_as = self. The same_base should have the base (global) batch info.
     self._same_for_batch_ctx = {}  # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim]  # nopep8
 
@@ -739,7 +744,7 @@ def is_equal(self, other, ignore_feature_dim=False, allow_same_feature_dim=False
       # We currently use the description because the identity would not be the same
       # in case of template construction where a dim tag is once created for a template layer,
       # and then later again for the real layer.
-      if self.description == other.description:
+      if self.auto_generated and other.auto_generated and self.description == other.description:
         return True
     return False
 
@@ -781,7 +786,9 @@ def __hash__(self):
       return hash(base)
     if self.derived_from_op:
       return hash(self.derived_from_op)
-    return hash((base.kind, base.dimension, base.description))
+    if self.auto_generated:
+      return hash((base.kind, base.dimension, base.description))
+    return hash(id(base))
 
   def get_same_base(self):
     """
@@ -1168,7 +1175,8 @@ def _make_constant_static_dim(cls, value, kind=None):
       dimension=value,
       kind=kind or Dim.Types.Unspecified,
       description="unnamed_%sdim_%i" % (kind.name + "_" if kind else "", value),
-      derived_from_op=Dim.Op(kind="constant", inputs=[], attribs={"value": value}))
+      derived_from_op=Dim.Op(kind="constant", inputs=[], attribs={"value": value}),
+      auto_generated=True)
 
   def _is_constant_static_dim(self):
     return self.derived_from_op and self.derived_from_op.kind == "constant"
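With the new __hash__ above, a dim tag that is not auto-generated hashes by identity, so user-defined tags only match when they are the same object or have been explicitly identified. A small sketch of that consequence (hypothetical tags, behavior inferred from this diff):

from returnn.tf.util.data import Dim

time_a = Dim(kind=Dim.Types.Spatial, description="time", dimension=None)
time_b = Dim(kind=Dim.Types.Spatial, description="time", dimension=None)
print(time_a == time_b)  # expected: False, the same description alone is no longer enough

time_b.declare_same_as(time_a)  # explicit identification, as e.g. RangeLayer does for out_spatial_dim
print(time_a == time_b)  # expected: True, they now share the same base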
