Skip to content

Commit a10e882

Browse files
authored
Dim match_priority (#871)
Provides a solution for ambiguous dim tags, e.g. in VariableLayer for a square matrix. Via: rwth-i6/returnn_common#17 (comment)
1 parent 89667ef commit a10e882

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

returnn/tf/util/data.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
5858
vocab=None,
5959
dyn_size=None, dyn_size_ext=None,
6060
undefined=False, generic=False, special=False,
61+
match_priority=0,
6162
derived_from_tag=None, derived_from_op=None,
6263
batch=None, control_flow_ctx=None,
6364
src_data=None, src_axis=None):
@@ -80,6 +81,10 @@ def __init__(self, kind=Types.Unspecified, description=None,
8081
the behavior is to consider them as equal,
8182
and assume that the chain of operations (e.g. padding + valid conv) results in the same dim.
8283
:param Dim.Op|None derived_from_op:
84+
:param int match_priority: when there is ambiguity between multiple dim tags, this value defines the order
85+
in which the dimensions are assigned to their matching counterparts.
86+
A dimension tag with a higher priority value is assigned first.
87+
E.g. for a square matrix used for a linear transformation, the reduce dim tag should have a higher priority.
8388
:param BatchInfo|None batch: for batch-dim, or dynamic dims per batch
8489
:param ControlFlowContext|None control_flow_ctx:
8590
:param Data|None src_data:
@@ -98,6 +103,7 @@ def __init__(self, kind=Types.Unspecified, description=None,
98103
self.derived_from_op = derived_from_op
99104
if derived_from_op and not derived_from_op.output:
100105
derived_from_op.output = self
106+
self.match_priority = match_priority
101107
if src_data:
102108
assert isinstance(src_data, Data) and isinstance(src_axis, int)
103109
if not batch and dyn_size_ext:
@@ -189,11 +195,12 @@ def __deepcopy__(self, memo=None):
189195
"""
190196
return self
191197

192-
def copy(self, same_as_self, description=None, kind=None):
198+
def copy(self, same_as_self=True, description=None, kind=None, match_priority=None):
193199
"""
194200
:param bool same_as_self:
195201
:param str|None description: new description
196202
:param Entity|None kind: if set, overwrites self.kind
203+
:param int|None match_priority: if set, overwrites self.match_priority
197204
:return: copy, maybe as new kind. setting same_as to self
198205
:rtype: Dim
199206
"""
@@ -202,6 +209,7 @@ def copy(self, same_as_self, description=None, kind=None):
202209
assert description is not None, "%s copy with not same_as_self should have a new description" % self
203210
tag = Dim(
204211
kind=kind or self.kind, description=description or self.description,
212+
match_priority=match_priority if match_priority is not None else self.match_priority,
205213
dimension=self.dimension, dyn_size_ext=self.dyn_size_ext,
206214
batch=self.batch,
207215
src_data=self.src_data, src_axis=self.src_axis)
@@ -4534,7 +4542,12 @@ def get_axes_from_description(self, axes, allow_int=NotSpecified):
45344542
# Once dim tags are no longer guaranteed to be unique, multiple axes could match.
45354543
# https://github.com/rwth-i6/returnn/issues/632
45364544
dims = [i for (i, tag) in enumerate(self.dim_tags) if tag == axes]
4537-
assert len(dims) <= 1, "%s: matching dim %s must be unique" % (self, axes)
4545+
if len(dims) > 1:
4546+
max_match_priority = max(self.dim_tags[i].match_priority for i in dims)
4547+
dims = [i for i in dims if self.dim_tags[i].match_priority == max_match_priority]
4548+
assert len(dims) <= 1, (
4549+
"%s: matching dim %s must be unique,"
4550+
" use `match_priority` to resolve the matching order of ambiguous dimensions" % (self, axes))
45384551
return dims
45394552
if isinstance(axes, int):
45404553
self._verify_axis_int_from_description(allow_int=allow_int)

tests/test_TFNetworkLayer.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4591,6 +4591,41 @@ def test_DotLayer2():
45914591
assert_equal(out.shape, (S1, S2, B, V))
45924592

45934593

4594+
def test_DotLayer_linear_square_matrix():
4595+
from returnn.tf.util.data import batch_dim
4596+
time_dim = SpatialDim("time")
4597+
feat_dim = FeatureDim("feature", dimension=3)
4598+
config = Config({
4599+
"extern_data": {
4600+
"data": {"dim_tags": [batch_dim, time_dim, feat_dim]},
4601+
"matrix_ambiguous": {"dim_tags": [feat_dim, feat_dim], "available_for_inference": True},
4602+
"matrix_non_ambiguous": {
4603+
"dim_tags": [feat_dim.copy(match_priority=1), feat_dim], "available_for_inference": True},
4604+
},
4605+
})
4606+
with make_scope() as session:
4607+
net = TFNetwork(config=config)
4608+
try:
4609+
net.construct_from_dict({
4610+
"output": {
4611+
"class": "dot", "from": ["data:data", "data:matrix_ambiguous"], "reduce": feat_dim
4612+
},
4613+
})
4614+
except Exception as exc:
4615+
print("Expected exception: %r" % exc)
4616+
assert "must be unique" in str(exc)
4617+
else:
4618+
raise Exception("Expected exception but constructed layer: %s" % net.get_default_output_layer())
4619+
net.construct_from_dict({
4620+
"output": {
4621+
"class": "dot", "from": ["data:data", "data:matrix_non_ambiguous"], "reduce": feat_dim
4622+
},
4623+
})
4624+
out = net.get_default_output_layer().output
4625+
assert out.dim_tags == (batch_dim, time_dim, feat_dim)
4626+
session.run(out.placeholder, feed_dict=make_feed_dict(net.extern_data))
4627+
4628+
45944629
def test_DotLayer_mask_dyn_seq():
45954630
batch = Dim(kind=Dim.Types.Batch, description="batch")
45964631
time = SpatialDim("time")

0 commit comments

Comments
 (0)