@@ -792,6 +792,59 @@ def vocab(self, vocab):
792
792
self .get_same_base ()._vocab = vocab
793
793
794
794
795
+ # Global dim tag placeholders.
796
+ BatchDim = DimensionTag (kind = DimensionTag .Types .Batch , description = "global batch" )
797
+
798
+
799
+ class _ImplicitDim :
800
+ """
801
+ Represents an implicit dim (dim tag) in :class:`Data`.
802
+ https://github.com/rwth-i6/returnn/issues/706
803
+ """
804
+ def __init__ (self , tag ):
805
+ """
806
+ :param DimensionTag tag:
807
+ """
808
+ self .tag = tag
809
+
810
+ def __repr__ (self ):
811
+ return "%s(%r)" % (self .__class__ .__name__ , self .tag )
812
+
813
+ def _eq_tuple (self ):
814
+ return self .__class__ , self .tag
815
+
816
+ def __hash__ (self ):
817
+ return hash (self ._eq_tuple ())
818
+
819
+ def __eq__ (self , other ):
820
+ if isinstance (other , _ImplicitDim ):
821
+ return self ._eq_tuple () == other ._eq_tuple ()
822
+ return False
823
+
824
+ def __ne__ (self , other ):
825
+ return not (self == other )
826
+
827
+
828
+ class ImplicitSparseDim (_ImplicitDim ):
829
+ """
830
+ Represents an implicit dim via Data.sparse_dim.
831
+ """
832
+
833
+
834
+ class ImplicitDynSizeDim (_ImplicitDim ):
835
+ """
836
+ Represents an implicit dim via dynamic dim sizes.
837
+ https://github.com/rwth-i6/returnn/issues/706
838
+ (For example via :class:`CumConcatLayer`.)
839
+ """
840
+
841
+
842
+ class VerifyOutShapeException (Exception ):
843
+ """
844
+ Exception via :func:`Data.verify_out_shape`.
845
+ """
846
+
847
+
795
848
class BatchInfo :
796
849
"""
797
850
A batched tensor is a tensor with batch dimension,
@@ -1793,6 +1846,51 @@ def get_runtime_sanity_check_op(self):
1793
1846
checks += [dyn_size_ext .get_runtime_sanity_check_op ()]
1794
1847
return tf .group (* checks )
1795
1848
1849
+ def verify_out_shape (self , out_shape ):
1850
+ """
1851
+ Verifies that ``out_shape`` matches our shape, i.e. specifically the dim tags.
1852
+ https://github.com/rwth-i6/returnn/issues/706
1853
+ Throws an exception if this is not the case.
1854
+
1855
+ :param set[DimensionTag|_ImplicitDim]|tuple|list out_shape:
1856
+ It must be a set, with the only exception when it is empty (then it doesn't matter).
1857
+ See :func:`dim_tags_set`.
1858
+ """
1859
+ self_dim_tags = self .dim_tags_set_implicit
1860
+ self_dim_tags_implicit_only = self .dim_tags_set_implicit_only_wrapped
1861
+ if not out_shape :
1862
+ if self_dim_tags :
1863
+ raise VerifyOutShapeException (
1864
+ "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self , self_dim_tags , out_shape ))
1865
+ return
1866
+ if not isinstance (out_shape , set ):
1867
+ raise TypeError ("%s verify_out_shape: expects a set but got %s" % (self , type (out_shape )))
1868
+ remaining = set (self_dim_tags )
1869
+ for dim in out_shape :
1870
+ if isinstance (dim , DimensionTag ):
1871
+ dim_tag = dim
1872
+ elif isinstance (dim , _ImplicitDim ):
1873
+ dim_tag = dim .tag
1874
+ if dim not in self_dim_tags_implicit_only :
1875
+ raise VerifyOutShapeException (
1876
+ "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % (
1877
+ self , self_dim_tags , out_shape , dim ))
1878
+ else :
1879
+ raise TypeError ("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % (
1880
+ self , out_shape , type (dim )))
1881
+ if dim_tag not in remaining :
1882
+ if dim_tag in self_dim_tags : # can happen e.g. if specified once as implicit dim and then also as explicit
1883
+ raise VerifyOutShapeException (
1884
+ "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % (
1885
+ self , self_dim_tags , out_shape , dim ))
1886
+ raise VerifyOutShapeException (
1887
+ "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % (
1888
+ self , self_dim_tags , out_shape , dim ))
1889
+ remaining .discard (dim_tag )
1890
+ if remaining :
1891
+ raise VerifyOutShapeException (
1892
+ "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self , remaining , out_shape ))
1893
+
1796
1894
def get_placeholder_kwargs (self , with_batch = True ):
1797
1895
"""
1798
1896
:param bool with_batch:
@@ -2860,6 +2958,53 @@ def dim_tags_sparse(self):
2860
2958
return self .dim_tags
2861
2959
return self .dim_tags [:self .feature_dim_axis ] + self .dim_tags [self .feature_dim_axis + 1 :]
2862
2960
2961
+ @property
2962
+ def dim_tags_set_implicit_only_wrapped (self ):
2963
+ """
2964
+ :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims.
2965
+ Also see :func:`dim_tags_set`.
2966
+ :rtype: set[_ImplicitDim]
2967
+ """
2968
+ self_dim_tags = set (self .dim_tags )
2969
+ dims = set ()
2970
+ if self .sparse_dim and self .sparse_dim not in self_dim_tags :
2971
+ dims .add (ImplicitSparseDim (self .sparse_dim ))
2972
+ for dim in self .dim_tags :
2973
+ if dim .dyn_size_ext :
2974
+ for dim_ in dim .dyn_size_ext .dim_tags :
2975
+ if dim_ not in self_dim_tags :
2976
+ dims .add (ImplicitDynSizeDim (dim_ ))
2977
+ return dims
2978
+
2979
+ @property
2980
+ def dim_tags_set_implicit_only (self ):
2981
+ """
2982
+ :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims.
2983
+ Also see :func:`dim_tags_set`.
2984
+ :rtype: set[DimensionTag]
2985
+ """
2986
+ return set (dim .tag for dim in self .dim_tags_set_implicit_only_wrapped )
2987
+
2988
+ @property
2989
+ def dim_tags_set_implicit (self ):
2990
+ """
2991
+ This is mostly intended to be used for verification, such as ``out_shape`` in a layer.
2992
+ https://github.com/rwth-i6/returnn/issues/706
2993
+
2994
+ We return a set because when dim tags (dimensions, and the shape) are checked,
2995
+ we never want that the order plays any role.
2996
+ https://github.com/rwth-i6/returnn/wiki/RETURNN-principles
2997
+ Further, dimension tags should ideally be unique.
2998
+ https://github.com/rwth-i6/returnn/issues/632
2999
+ (This is not enforced currently, but we should not treat this specially now.)
3000
+
3001
+ :return: set of dim tags
3002
+ :rtype: set[DimensionTag]
3003
+ """
3004
+ dims = set (self .dim_tags )
3005
+ dims .update (self .dim_tags_set_implicit_only )
3006
+ return dims
3007
+
2863
3008
@property
2864
3009
def ndim (self ):
2865
3010
"""
0 commit comments