Skip to content

Commit e46b559

Browse files
authored
Dim get_size_dim_tag fix (#1148)
Also fixes Dim.is_dynamic. Related: #1139 Also needed for #1143.
1 parent 355c9bb commit e46b559

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

returnn/tf/layers/rec.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1886,9 +1886,7 @@ def get_layer(name):
18861886
prev_layer = prev_layers[layer_name]
18871887
assert layer.output.batch_shape == prev_layer.output.batch_shape
18881888
assert layer.output.batch_dim_axis == prev_layer.output.batch_dim_axis
1889-
assert sorted(layer.output.size_placeholder.keys()) == sorted(prev_layer.output.size_placeholder.keys())
1890-
for i in range(len(layer.output.size_placeholder)):
1891-
assert layer.output.get_size_dim_tag(i) == prev_layer.output.get_size_dim_tag(i)
1889+
assert layer.output.get_dyn_size_tags() == prev_layer.output.get_dyn_size_tags()
18921890

18931891
def get_prev_template_layer(self, layer_name):
18941892
"""

returnn/tf/util/data.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def is_dynamic(self):
518518
:return: whether the dim is not static. usually means that it has seq lengths
519519
:rtype: bool
520520
"""
521-
return self.dimension is not None
521+
return self.dimension is None and not self.is_batch_dim()
522522

523523
def can_be_used_as_dim(self):
524524
"""
@@ -5412,13 +5412,20 @@ def get_time_dim_tag(self):
54125412
assert self.time_dim_axis is not None
54135413
return self.get_dim_tag(self.time_dim_axis)
54145414

5415+
def get_dyn_size_tags(self):
5416+
"""
5417+
:return: all dim tags with dynamic size
5418+
:rtype: list[Dim]
5419+
"""
5420+
return [dim_tag for dim_tag in self._dim_tags if dim_tag.is_dynamic()]
5421+
54155422
def get_size_dim_tag(self, number):
54165423
"""
54175424
:param int number: index in sorted(size_placeholder.keys())
54185425
:rtype: Dim
54195426
"""
5420-
axis_wo_batch = sorted(self.size_placeholder.keys())[number]
5421-
return self.get_dim_tag(self.get_batch_axis(axis_wo_batch))
5427+
dyn_size_tags = self.get_dyn_size_tags()
5428+
return dyn_size_tags[number]
54225429

54235430
def get_batch_shape_dim_tags(self):
54245431
"""

0 commit comments

Comments
 (0)