Skip to content

Commit 81525df

Browse files
committed
Data.dim_tags_set, Data.verify_out_shape
#706
1 parent 6228964 commit 81525df

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

returnn/tf/util/data.py

+145
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,59 @@ def vocab(self, vocab):
792792
self.get_same_base()._vocab = vocab
793793

794794

795+
# Global dim tag placeholders.
796+
# Module-level DimensionTag instance of kind ``Batch``, described as "global batch".
BatchDim = DimensionTag(kind=DimensionTag.Types.Batch, description="global batch")
797+
798+
799+
class _ImplicitDim:
  """
  Wrapper around a dim tag which marks it as an implicit dim in :class:`Data`.
  https://github.com/rwth-i6/returnn/issues/706

  Instances are hashable. Two instances compare equal
  if and only if they have the same class and equal tags.
  """

  def __init__(self, tag):
    """
    :param DimensionTag tag: the wrapped dim tag
    """
    self.tag = tag

  def __repr__(self):
    cls_name = self.__class__.__name__
    return "%s(%r)" % (cls_name, self.tag)

  def _eq_tuple(self):
    """
    :return: (class, tag), the key used for both hashing and equality
    :rtype: tuple
    """
    key = (self.__class__, self.tag)
    return key

  def __hash__(self):
    key = self._eq_tuple()
    return hash(key)

  def __eq__(self, other):
    # Anything which is not an implicit-dim wrapper is never equal to us.
    if not isinstance(other, _ImplicitDim):
      return False
    return self._eq_tuple() == other._eq_tuple()

  def __ne__(self, other):
    equal = (self == other)
    return not equal
826+
827+
828+
class ImplicitSparseDim(_ImplicitDim):
  """
  Marks a dim tag as being implicit through ``Data.sparse_dim``,
  i.e. it is the sparse dim but it is not among the explicit dim tags.
  """
832+
833+
834+
class ImplicitDynSizeDim(_ImplicitDim):
  """
  Marks a dim tag as being implicit through the dynamic sizes of some other dim,
  i.e. it occurs in the dim tags of a ``dyn_size_ext`` but it is not among the explicit dim tags.
  https://github.com/rwth-i6/returnn/issues/706
  (One example where this occurs is :class:`CumConcatLayer`.)
  """
840+
841+
842+
class VerifyOutShapeException(Exception):
  """
  Raised by :func:`Data.verify_out_shape` when the given ``out_shape`` does not match the dims.
  """
846+
847+
795848
class BatchInfo:
796849
"""
797850
A batched tensor is a tensor with batch dimension,
@@ -1793,6 +1846,51 @@ def get_runtime_sanity_check_op(self):
17931846
checks += [dyn_size_ext.get_runtime_sanity_check_op()]
17941847
return tf.group(*checks)
17951848

1849+
def verify_out_shape(self, out_shape):
  """
  Checks that ``out_shape`` names exactly our dim tags (explicit and implicit ones),
  each one exactly once. The order plays no role.
  https://github.com/rwth-i6/returnn/issues/706

  Every element of ``out_shape`` is either a plain dim tag,
  or a dim tag wrapped in :class:`_ImplicitDim`.
  A wrapped one must be one of our implicit-only dims
  (see :func:`dim_tags_set_implicit_only_wrapped`).

  :param set[DimensionTag|_ImplicitDim]|tuple|list out_shape:
    Must be a set, unless it is empty (then the type does not matter).
    It is compared against :func:`dim_tags_set_implicit`.
  :raises VerifyOutShapeException: if ``out_shape`` does not match our dims
  :raises TypeError: if a non-empty ``out_shape`` is not a set,
    or if it contains something which is neither a dim tag nor an :class:`_ImplicitDim`
  """
  own_dims = self.dim_tags_set_implicit
  own_implicit_wrapped = self.dim_tags_set_implicit_only_wrapped

  # Anything empty (of any type, also None) is accepted only if we do not have any dims either.
  if not out_shape:
    if not own_dims:
      return
    raise VerifyOutShapeException(
      "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self, own_dims, out_shape))

  if not isinstance(out_shape, set):
    raise TypeError("%s verify_out_shape: expects a set but got %s" % (self, type(out_shape)))

  # Every entry of out_shape ticks off one of our dims. Nothing must be left over in the end.
  unmatched = set(own_dims)
  for entry in out_shape:
    if isinstance(entry, DimensionTag):
      tag = entry
    elif isinstance(entry, _ImplicitDim):
      tag = entry.tag
      if entry not in own_implicit_wrapped:
        raise VerifyOutShapeException(
          "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % (
            self, own_dims, out_shape, entry))
    else:
      raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % (
        self, out_shape, type(entry)))

    if tag in unmatched:
      unmatched.discard(tag)
      continue

    # Not available anymore. Either it was ticked off already
    # (e.g. specified once as implicit dim and then also as explicit),
    # or it is not one of our dims at all.
    template = (
      "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape"
      if tag in own_dims else
      "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self")
    raise VerifyOutShapeException(template % (self, own_dims, out_shape, entry))

  if unmatched:
    raise VerifyOutShapeException(
      "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self, unmatched, out_shape))
1893+
17961894
def get_placeholder_kwargs(self, with_batch=True):
17971895
"""
17981896
:param bool with_batch:
@@ -2860,6 +2958,53 @@ def dim_tags_sparse(self):
28602958
return self.dim_tags
28612959
return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:]
28622960

2961+
@property
def dim_tags_set_implicit_only_wrapped(self):
  """
  Collects the dims which are implicit only, i.e. which are not among the explicit dim tags,
  each one wrapped such that its origin is visible:
  :class:`ImplicitSparseDim` for the sparse dim,
  :class:`ImplicitDynSizeDim` for dims of the dynamic sizes (``dyn_size_ext``) of our dim tags.
  Also see :func:`dim_tags_set_implicit`.

  :return: set of wrapped implicit-only dim tags
  :rtype: set[_ImplicitDim]
  """
  explicit = set(self.dim_tags)
  wrapped = set()

  sparse = self.sparse_dim
  if sparse and sparse not in explicit:
    wrapped.add(ImplicitSparseDim(sparse))

  for tag in self.dim_tags:
    size_data = tag.dyn_size_ext
    if not size_data:
      continue  # no dynamic sizes attached, so nothing implicit via this dim
    wrapped.update(
      ImplicitDynSizeDim(size_tag) for size_tag in size_data.dim_tags if size_tag not in explicit)

  return wrapped
2978+
2979+
@property
def dim_tags_set_implicit_only(self):
  """
  Same dims as in :func:`dim_tags_set_implicit_only_wrapped` but unwrapped,
  i.e. the plain dim tags which are implicit (via sparse dim or dynamic sizes)
  and which are not among the explicit dim tags.
  Also see :func:`dim_tags_set_implicit`.

  :return: set of implicit-only dim tags
  :rtype: set[DimensionTag]
  """
  return {wrapped.tag for wrapped in self.dim_tags_set_implicit_only_wrapped}
2987+
2988+
@property
def dim_tags_set_implicit(self):
  """
  All our dim tags, i.e. the explicit ones (``dim_tags``)
  together with the implicit-only ones (:func:`dim_tags_set_implicit_only`).

  The main use case is verification, e.g. of ``out_shape`` in a layer.
  https://github.com/rwth-i6/returnn/issues/706

  This is a set because the order should never play any role
  whenever dim tags (dimensions, and the shape) are checked.
  https://github.com/rwth-i6/returnn/wiki/RETURNN-principles
  Also, dimension tags should ideally be unique.
  https://github.com/rwth-i6/returnn/issues/632
  (This is not enforced currently, but we should not treat this specially now.)

  :return: set of dim tags
  :rtype: set[DimensionTag]
  """
  return set(self.dim_tags).union(self.dim_tags_set_implicit_only)
3007+
28633008
@property
28643009
def ndim(self):
28653010
"""

0 commit comments

Comments
 (0)