@@ -58,6 +58,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
58
58
vocab = None ,
59
59
dyn_size = None , dyn_size_ext = None ,
60
60
undefined = False , generic = False , special = False ,
61
+ match_priority = 0 ,
61
62
derived_from_tag = None , derived_from_op = None ,
62
63
batch = None , control_flow_ctx = None ,
63
64
src_data = None , src_axis = None ):
@@ -80,6 +81,10 @@ def __init__(self, kind=Types.Unspecified, description=None,
80
81
the behavior is to consider them as equal,
81
82
and assume that the chain of operations (e.g. padding + valid conv) results in the same dim.
82
83
:param Dim.Op|None derived_from_op:
84
+ :param int match_priority: when there is ambiguity between multiple dim tags, this value defines the order
85
+ in which the dimension are assigned to their matching counterparts.
86
+ A dimension tag with a higher priority value is assigned first.
87
+ E.g. for a square matrix used for a linear transformation, the reduce dim tag should have a higher priority.
83
88
:param BatchInfo|None batch: for batch-dim, or dynamic dims per batch
84
89
:param ControlFlowContext|None control_flow_ctx:
85
90
:param Data|None src_data:
@@ -98,6 +103,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
98
103
self .derived_from_op = derived_from_op
99
104
if derived_from_op and not derived_from_op .output :
100
105
derived_from_op .output = self
106
+ self .match_priority = match_priority
101
107
if src_data :
102
108
assert isinstance (src_data , Data ) and isinstance (src_axis , int )
103
109
if not batch and dyn_size_ext :
@@ -189,11 +195,12 @@ def __deepcopy__(self, memo=None):
189
195
"""
190
196
return self
191
197
192
- def copy (self , same_as_self , description = None , kind = None ):
198
+ def copy (self , same_as_self = True , description = None , kind = None , match_priority = None ):
193
199
"""
194
200
:param bool same_as_self:
195
201
:param str|None description: new description
196
202
:param Entity|None kind: if set, overwrites self.kind
203
+ :param int|None match_priority:
197
204
:return: copy, maybe as new kind. setting same_as to self
198
205
:rtype: Dim
199
206
"""
@@ -202,6 +209,7 @@ def copy(self, same_as_self, description=None, kind=None):
202
209
assert description is not None , "%s copy with not same_as_self should have a new description" % self
203
210
tag = Dim (
204
211
kind = kind or self .kind , description = description or self .description ,
212
+ match_priority = match_priority if match_priority is not None else self .match_priority ,
205
213
dimension = self .dimension , dyn_size_ext = self .dyn_size_ext ,
206
214
batch = self .batch ,
207
215
src_data = self .src_data , src_axis = self .src_axis )
@@ -4534,7 +4542,12 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
4534
4542
# Once we have not guaranteed unique dim tags, multiple axes could match.
4535
4543
# https://github.com/rwth-i6/returnn/issues/632
4536
4544
dims = [i for (i , tag ) in enumerate (self .dim_tags ) if tag == axes ]
4537
- assert len (dims ) <= 1 , "%s: matching dim %s must be unique" % (self , axes )
4545
+ if len (dims ) > 1 :
4546
+ max_match_priority = max (self .dim_tags [i ].match_priority for i in dims )
4547
+ dims = [i for i in dims if self .dim_tags [i ].match_priority == max_match_priority ]
4548
+ assert len (dims ) <= 1 , (
4549
+ "%s: matching dim %s must be unique,"
4550
+ " use `match_priority` to resolve the matching order of ambiguous dimensions" % (self , axes ))
4538
4551
return dims
4539
4552
if isinstance (axes , int ):
4540
4553
self ._verify_axis_int_from_description (allow_int = allow_int )
0 commit comments