@@ -7045,13 +7045,13 @@ class EditDistanceTableLayer(LayerBase):
   layer_class = "edit_distance_table"
   recurrent = True
 
-  def __init__(self, debug=False, blank_idx=None, **kwargs):
+  def __init__(self, debug=False, blank_idx=None, out_dim=None, **kwargs):
     """
     :param bool debug:
     :param int|None blank_idx: if given, will keep the same row for this source label
+    :param DimensionTag|None out_dim:
     """
-    from returnn.tf.util.basic import where_bc
-    super(EditDistanceTableLayer, self).__init__(**kwargs)
+    super(EditDistanceTableLayer, self).__init__(out_dim=out_dim, **kwargs)
     assert len(self.sources) == 1, "%s: expects exactly a single source" % self
     source_data = self.sources[0].output
     assert source_data.dtype == "int32" and source_data.batch_ndim <= 2 and source_data.sparse
@@ -7076,8 +7076,7 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
     mask_flag = rec_step_info.get_prev_end_flag(target_search_choices=self.get_search_choices())
     source = source_data.placeholder
     if blank_idx is None:
-      from returnn.tf.util.basic import expand_dims_unbroadcast
-      source_len = expand_dims_unbroadcast(rec_step_info.step, axis=0, dim=batch_dim)
+      source_len = tf_util.expand_dims_unbroadcast(rec_step_info.step, axis=0, dim=batch_dim)
     else:
       source_len = self._rec_previous_layer.rec_vars_outputs["source_len"]
       mask_flag = tf.logical_or(mask_flag, tf.equal(source, blank_idx))
@@ -7088,9 +7087,8 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
       a_ended=mask_flag,
       b=target_data.placeholder, b_len=target_data.get_sequence_lengths())
     if blank_idx is not None:
-      self.rec_vars_outputs["source_len"] = source_len + where_bc(mask_flag, 0, 1)
+      self.rec_vars_outputs["source_len"] = source_len + tf_util.where_bc(mask_flag, 0, 1)
     if debug:
-      from returnn.tf.util.basic import py_print, vocab_idx_repr
       print_out = [str(self)]
       choice = self.get_search_choices()
       if choice:
@@ -7100,11 +7098,11 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
       print_out += [
         "a_n", rec_step_info.step,
         "a_ended", rec_step_info.get_prev_end_flag(target_search_choices=self.get_search_choices()),
-        "a", vocab_idx_repr(source_data.placeholder, target_data),
-        "b", vocab_idx_repr(target_data.placeholder, target_data),
+        "a", tf_util.vocab_idx_repr(source_data.placeholder, target_data),
+        "b", tf_util.vocab_idx_repr(target_data.placeholder, target_data),
         "b_len", target_data.get_sequence_lengths(),
         "last_row", self._last_row, "next_row", self._next_row]
-      self._next_row = py_print(self._next_row, print_out)
+      self._next_row = tf_util.py_print(self._next_row, print_out)
     self.rec_vars_outputs["state"] = self._next_row
     self._reduce_out = None  # see get_sub_layer
     self.output.placeholder = self._next_row
@@ -7126,12 +7124,11 @@ def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, sources, name, targ
     if source_data.time_dim_axis is not None:
       return {}
     # expects inside rec layer
-    from returnn.tf.util.basic import expand_dims_unbroadcast
     assert target, "%s %r: 'target' must be set" % (cls.__name__, name)
     target_data = cls._static_get_target_value(target=target, network=network)
     assert target_data, "target %r not found?" % target
     n_time = tf.shape(target_data.placeholder)[target_data.time_dim_axis]
-    d = {"state": expand_dims_unbroadcast(tf.range(n_time + 1), axis=0, dim=batch_dim)}
+    d = {"state": tf_util.expand_dims_unbroadcast(tf.range(n_time + 1), axis=0, dim=batch_dim)}
     if kwargs.get("blank_idx", None) is not None:
       d["source_len"] = tf.zeros((batch_dim,), dtype=tf.int32)
     return d
@@ -7155,33 +7152,44 @@ def transform_config_dict(cls, d, network, get_layer):
     super(EditDistanceTableLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
 
   @classmethod
-  def get_out_data_from_opts(cls, name, sources, target, network, _target_layers=None, blank_idx=None, **kwargs):
+  def get_out_data_from_opts(cls, name, sources, target, network, _target_layers=None, blank_idx=None,
+                             out_dim=None, **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
+    :param returnn.tf.network.TFNetwork network:
     :param str target:
     :param dict[str,LayerBase] _target_layers:
     :param int|None blank_idx:
-    :param returnn.tf.network.TFNetwork network:
+    :param DimensionTag|None out_dim:
     :rtype: Data
     """
     assert len(sources) == 1, "%s %r: expects exactly a single source" % (cls.__name__, name)
     source_data = sources[0].output
     assert target, "%s %r: 'target' must be set" % (cls.__name__, name)
     target_data = cls._static_get_target_value(target=target, _target_layers=_target_layers, network=network)
     assert target_data, "target %r not found?" % target
+    in_dim = target_data.get_time_dim_tag()
+    if not out_dim:
+      out_dim = DimensionTag(
+        kind=in_dim.kind, description="%s:edit_dist_table" % name,
+        dimension=in_dim.dimension + 1 if in_dim.dimension else None,
+        batch=in_dim.batch, control_flow_ctx=in_dim.control_flow_ctx)
     seq_len = tf_util.new_seq_len(
       func=tf_util.simplify_add, key=tf_util.simplify_add,
-      dim_tag_desc="edit_dist_table:%s" % name,
+      dim_tag_desc="%s:edit_dist_table" % name,
       a=target_data.get_sequence_lengths(), b=1)
     tag = DimensionTag.get_tag_from_size_tensor(seq_len)
-    assert tag
+    if tag:
+      tag.declare_same_as(out_dim)
+    else:
+      out_dim.dyn_size = seq_len
     return Data(
       name="%s_output" % name,
       dim_tags=(
-        [source_data.get_batch_dim_tag(), source_data.get_time_dim_tag(), tag]
+        [source_data.get_batch_dim_tag(), source_data.get_time_dim_tag(), out_dim]
         if source_data.have_time_axis() else
-        [source_data.get_batch_dim_tag(), tag]),
+        [source_data.get_batch_dim_tag(), out_dim]),
       dtype="int32", beam=SearchBeam.get_combined_beam(source_data.beam, target_data.beam))
 
 