File tree 2 files changed +11
-6
lines changed
2 files changed +11
-6
lines changed Original file line number Diff line number Diff line change @@ -1886,9 +1886,7 @@ def get_layer(name):
1886
1886
prev_layer = prev_layers [layer_name ]
1887
1887
assert layer .output .batch_shape == prev_layer .output .batch_shape
1888
1888
assert layer .output .batch_dim_axis == prev_layer .output .batch_dim_axis
1889
- assert sorted (layer .output .size_placeholder .keys ()) == sorted (prev_layer .output .size_placeholder .keys ())
1890
- for i in range (len (layer .output .size_placeholder )):
1891
- assert layer .output .get_size_dim_tag (i ) == prev_layer .output .get_size_dim_tag (i )
1889
+ assert layer .output .get_dyn_size_tags () == prev_layer .output .get_dyn_size_tags ()
1892
1890
1893
1891
def get_prev_template_layer (self , layer_name ):
1894
1892
"""
Original file line number Diff line number Diff line change @@ -518,7 +518,7 @@ def is_dynamic(self):
518
518
:return: whether the dim is not static. usually means that it has seq lengths
519
519
:rtype: bool
520
520
"""
521
- return self .dimension is not None
521
+ return self .dimension is None and not self . is_batch_dim ()
522
522
523
523
def can_be_used_as_dim (self ):
524
524
"""
@@ -5412,13 +5412,20 @@ def get_time_dim_tag(self):
5412
5412
assert self .time_dim_axis is not None
5413
5413
return self .get_dim_tag (self .time_dim_axis )
5414
5414
5415
+ def get_dyn_size_tags (self ):
5416
+ """
5417
+ :return: all dim tags with dynamic size
5418
+ :rtype: list[Dim]
5419
+ """
5420
+ return [dim_tag for dim_tag in self ._dim_tags if dim_tag .is_dynamic ()]
5421
+
5415
5422
def get_size_dim_tag (self , number ):
5416
5423
"""
5417
5424
:param int number: index in sorted(size_placeholder.keys())
5418
5425
:rtype: Dim
5419
5426
"""
5420
- axis_wo_batch = sorted ( self .size_placeholder . keys ())[ number ]
5421
- return self . get_dim_tag ( self . get_batch_axis ( axis_wo_batch ))
5427
+ dyn_size_tags = self .get_dyn_size_tags ()
5428
+ return dyn_size_tags [ number ]
5422
5429
5423
5430
def get_batch_shape_dim_tags (self ):
5424
5431
"""
You can’t perform that action at this time.
0 commit comments