Skip to content

Commit ac02f4b

Browse files
committed
Dim auto_generated flag
Allows for better Dim is_equal which does not rely on the description. #634
1 parent cb4f182 commit ac02f4b

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

returnn/tf/util/data.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def copy(self, same_as_self=True, description=None, kind=None, match_priority=No
214214
assert description is not None, "%s copy with not same_as_self should have a new description" % self
215215
tag = Dim(
216216
kind=kind or self.kind, description=description or self.description,
217+
auto_generated=self.auto_generated,
217218
match_priority=match_priority if match_priority is not None else self.match_priority,
218219
dimension=self.dimension, dyn_size_ext=self.dyn_size_ext,
219220
batch=self.batch,
@@ -380,6 +381,7 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False):
380381
return None
381382
dim_tag = Dim(
382383
kind=self.kind, description=self.description, dimension=self.dimension,
384+
auto_generated=self.auto_generated,
383385
batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx,
384386
dyn_size_ext=dyn_size_ext)
385387
dim_tag.same_as = same_base
@@ -2743,7 +2745,7 @@ def template_from_constant(cls, x, name, dtype=None, shape=None, with_batch_dim=
27432745
assert d == d_
27442746
d = Dim(
27452747
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
2746-
description="%s:static:%i" % (name, i),
2748+
description="%s:static:%i" % (name, i), auto_generated=True,
27472749
dimension=d)
27482750
else:
27492751
raise TypeError("%r shape[%i] invalid type %r in shape %r" % (name, i, type(d), shape))
@@ -3409,7 +3411,8 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None):
34093411
dim_tag = dim_tag.copy(same_as_self=True, kind=Dim.Types.Spatial)
34103412
if not unbroadcast and dim_tag.dimension != 1:
34113413
dim_tag = Dim(
3412-
kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1)
3414+
kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1,
3415+
auto_generated=True)
34133416
data_opts["dim_tags"] = self.dim_tags[:axis] + (dim_tag,) + self.dim_tags[axis:]
34143417
other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True)
34153418
for k, a in other_special_axes.items():
@@ -3444,11 +3447,11 @@ def copy_split_feature_dim(self, new_feature_dim):
34443447
new_feature_dim_axis = self.feature_dim_axis + 1
34453448
data_opts = self.get_kwargs(include_special_axes=False)
34463449
dim_tag_split_rem = Dim(
3447-
kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem,
3450+
kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem, auto_generated=True,
34483451
dimension=feature_dim_rem)
34493452
dim_tag_new = Dim(
34503453
kind=self.dim_tags[self.feature_dim_axis].kind,
3451-
description="feature_split_new_%i" % new_feature_dim,
3454+
description="feature_split_new_%i" % new_feature_dim, auto_generated=True,
34523455
dimension=new_feature_dim)
34533456
dim_tags = (
34543457
self.dim_tags[:self.feature_dim_axis] +
@@ -3614,7 +3617,7 @@ def copy_time_flattened(self):
36143617
data_opts["placeholder"] = self.get_placeholder_time_flattened()
36153618
dim_tag = self.dim_tags[self.time_dim_axis]
36163619
dim_tag = Dim(
3617-
kind=Dim.Types.Spatial, description="%s_flattened" % (dim_tag.description or "unnamed"))
3620+
kind=Dim.Types.Spatial, description="%s_flattened" % (dim_tag.description or "unnamed"), auto_generated=True)
36183621
data_opts["dim_tags"] = (
36193622
(dim_tag,) +
36203623
tuple(tag for (i, tag) in enumerate(self.dim_tags) if i not in (self.batch_dim_axis, self.time_dim_axis)))
@@ -3818,7 +3821,7 @@ def copy_template_adding_time_dim(self, name=None, time_dim_axis=0):
38183821
assert time_dim_axis >= 0
38193822
assert 0 <= time_dim_axis <= self.batch_ndim
38203823
kwargs = self.get_kwargs(include_special_axes=False)
3821-
dim_tag = Dim(kind=Dim.Types.Time, description="unknown_time", dimension=None)
3824+
dim_tag = Dim(kind=Dim.Types.Time, description="unknown_time", dimension=None, auto_generated=True)
38223825
dim_tags = self.dim_tags[:time_dim_axis] + (dim_tag,) + self.dim_tags[time_dim_axis:]
38233826
kwargs["dim_tags"] = dim_tags
38243827
other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True)
@@ -3872,6 +3875,7 @@ def copy_template_replace_dim(self, axis, new_dim, new_size=None):
38723875
return self.copy_template() # nothing to do
38733876
dim_tag = Dim(
38743877
kind=dim_tag.kind, description="%s_replaced" % (dim_tag.description or "unnamed"),
3878+
auto_generated=True,
38753879
dimension=new_dim, dyn_size=new_size)
38763880
return self.copy_template_replace_dim_tag(axis=axis, new_dim_tag=dim_tag)
38773881

@@ -5572,7 +5576,7 @@ def _infer_dim_tags_tuple_from_shape(
55725576
if axis == feature_dim_axis and dyn_size is None and axis != time_dim_axis:
55735577
tag = Dim(
55745578
kind=Dim.Types.Feature, dimension=dim, description="feature:%s" % name,
5575-
undefined=dim is None)
5579+
undefined=dim is None, auto_generated=True)
55765580
else:
55775581
assert axis in spatial_axes
55785582
description = "time" if axis == time_dim_axis else "spatial%i" % spatial_axes.index(axis)
@@ -5587,7 +5591,7 @@ def _infer_dim_tags_tuple_from_shape(
55875591
description += ":%s" % name
55885592
tag = Dim(
55895593
kind=Dim.Types.Spatial, description=description, dimension=dim, dyn_size=dyn_size,
5590-
undefined=dim is None and dyn_size is None)
5594+
undefined=dim is None and dyn_size is None, auto_generated=True)
55915595
dim_tags[axis] = tag
55925596
assert sorted(dim_tags.keys()) == list(range(len(batch_shape)))
55935597
return tuple(dim_tags[axis] for axis in range(len(batch_shape)))

0 commit comments

Comments
 (0)