Skip to content

Commit 43ce5a9

Browse files
authored
EditDistanceTableLayer, handle out_dim (#807)
#597
1 parent 88156c8 commit 43ce5a9

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

returnn/tf/layers/rec.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7045,13 +7045,13 @@ class EditDistanceTableLayer(LayerBase):
70457045
layer_class = "edit_distance_table"
70467046
recurrent = True
70477047

7048-
def __init__(self, debug=False, blank_idx=None, **kwargs):
7048+
def __init__(self, debug=False, blank_idx=None, out_dim=None, **kwargs):
70497049
"""
70507050
:param bool debug:
70517051
:param int|None blank_idx: if given, will keep the same row for this source label
7052+
:param DimensionTag|None out_dim:
70527053
"""
7053-
from returnn.tf.util.basic import where_bc
7054-
super(EditDistanceTableLayer, self).__init__(**kwargs)
7054+
super(EditDistanceTableLayer, self).__init__(out_dim=out_dim, **kwargs)
70557055
assert len(self.sources) == 1, "%s: expects exactly a single source" % self
70567056
source_data = self.sources[0].output
70577057
assert source_data.dtype == "int32" and source_data.batch_ndim <= 2 and source_data.sparse
@@ -7076,8 +7076,7 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
70767076
mask_flag = rec_step_info.get_prev_end_flag(target_search_choices=self.get_search_choices())
70777077
source = source_data.placeholder
70787078
if blank_idx is None:
7079-
from returnn.tf.util.basic import expand_dims_unbroadcast
7080-
source_len = expand_dims_unbroadcast(rec_step_info.step, axis=0, dim=batch_dim)
7079+
source_len = tf_util.expand_dims_unbroadcast(rec_step_info.step, axis=0, dim=batch_dim)
70817080
else:
70827081
source_len = self._rec_previous_layer.rec_vars_outputs["source_len"]
70837082
mask_flag = tf.logical_or(mask_flag, tf.equal(source, blank_idx))
@@ -7088,9 +7087,8 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
70887087
a_ended=mask_flag,
70897088
b=target_data.placeholder, b_len=target_data.get_sequence_lengths())
70907089
if blank_idx is not None:
7091-
self.rec_vars_outputs["source_len"] = source_len + where_bc(mask_flag, 0, 1)
7090+
self.rec_vars_outputs["source_len"] = source_len + tf_util.where_bc(mask_flag, 0, 1)
70927091
if debug:
7093-
from returnn.tf.util.basic import py_print, vocab_idx_repr
70947092
print_out = [str(self)]
70957093
choice = self.get_search_choices()
70967094
if choice:
@@ -7100,11 +7098,11 @@ def __init__(self, debug=False, blank_idx=None, **kwargs):
71007098
print_out += [
71017099
"a_n", rec_step_info.step,
71027100
"a_ended", rec_step_info.get_prev_end_flag(target_search_choices=self.get_search_choices()),
7103-
"a", vocab_idx_repr(source_data.placeholder, target_data),
7104-
"b", vocab_idx_repr(target_data.placeholder, target_data),
7101+
"a", tf_util.vocab_idx_repr(source_data.placeholder, target_data),
7102+
"b", tf_util.vocab_idx_repr(target_data.placeholder, target_data),
71057103
"b_len", target_data.get_sequence_lengths(),
71067104
"last_row", self._last_row, "next_row", self._next_row]
7107-
self._next_row = py_print(self._next_row, print_out)
7105+
self._next_row = tf_util.py_print(self._next_row, print_out)
71087106
self.rec_vars_outputs["state"] = self._next_row
71097107
self._reduce_out = None # see get_sub_layer
71107108
self.output.placeholder = self._next_row
@@ -7126,12 +7124,11 @@ def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, sources, name, targ
71267124
if source_data.time_dim_axis is not None:
71277125
return {}
71287126
# expects inside rec layer
7129-
from returnn.tf.util.basic import expand_dims_unbroadcast
71307127
assert target, "%s %r: 'target' must be set" % (cls.__name__, name)
71317128
target_data = cls._static_get_target_value(target=target, network=network)
71327129
assert target_data, "target %r not found?" % target
71337130
n_time = tf.shape(target_data.placeholder)[target_data.time_dim_axis]
7134-
d = {"state": expand_dims_unbroadcast(tf.range(n_time + 1), axis=0, dim=batch_dim)}
7131+
d = {"state": tf_util.expand_dims_unbroadcast(tf.range(n_time + 1), axis=0, dim=batch_dim)}
71357132
if kwargs.get("blank_idx", None) is not None:
71367133
d["source_len"] = tf.zeros((batch_dim,), dtype=tf.int32)
71377134
return d
@@ -7155,33 +7152,44 @@ def transform_config_dict(cls, d, network, get_layer):
71557152
super(EditDistanceTableLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
71567153

71577154
@classmethod
7158-
def get_out_data_from_opts(cls, name, sources, target, network, _target_layers=None, blank_idx=None, **kwargs):
7155+
def get_out_data_from_opts(cls, name, sources, target, network, _target_layers=None, blank_idx=None,
7156+
out_dim=None, **kwargs):
71597157
"""
71607158
:param str name:
71617159
:param list[LayerBase] sources:
7160+
:param returnn.tf.network.TFNetwork network:
71627161
:param str target:
71637162
:param dict[str,LayerBase] _target_layers:
71647163
:param int|None blank_idx:
7165-
:param returnn.tf.network.TFNetwork network:
7164+
:param DimensionTag|None out_dim:
71667165
:rtype: Data
71677166
"""
71687167
assert len(sources) == 1, "%s %r: expects exactly a single source" % (cls.__name__, name)
71697168
source_data = sources[0].output
71707169
assert target, "%s %r: 'target' must be set" % (cls.__name__, name)
71717170
target_data = cls._static_get_target_value(target=target, _target_layers=_target_layers, network=network)
71727171
assert target_data, "target %r not found?" % target
7172+
in_dim = target_data.get_time_dim_tag()
7173+
if not out_dim:
7174+
out_dim = DimensionTag(
7175+
kind=in_dim.kind, description="%s:edit_dist_table" % name,
7176+
dimension=in_dim.dimension + 1 if in_dim.dimension else None,
7177+
batch=in_dim.batch, control_flow_ctx=in_dim.control_flow_ctx)
71737178
seq_len = tf_util.new_seq_len(
71747179
func=tf_util.simplify_add, key=tf_util.simplify_add,
7175-
dim_tag_desc="edit_dist_table:%s" % name,
7180+
dim_tag_desc="%s:edit_dist_table" % name,
71767181
a=target_data.get_sequence_lengths(), b=1)
71777182
tag = DimensionTag.get_tag_from_size_tensor(seq_len)
7178-
assert tag
7183+
if tag:
7184+
tag.declare_same_as(out_dim)
7185+
else:
7186+
out_dim.dyn_size = seq_len
71797187
return Data(
71807188
name="%s_output" % name,
71817189
dim_tags=(
7182-
[source_data.get_batch_dim_tag(), source_data.get_time_dim_tag(), tag]
7190+
[source_data.get_batch_dim_tag(), source_data.get_time_dim_tag(), out_dim]
71837191
if source_data.have_time_axis() else
7184-
[source_data.get_batch_dim_tag(), tag]),
7192+
[source_data.get_batch_dim_tag(), out_dim]),
71857193
dtype="int32", beam=SearchBeam.get_combined_beam(source_data.beam, target_data.beam))
71867194

71877195

0 commit comments

Comments (0)