46 changes: 36 additions & 10 deletions returnn/tf/layers/basic.py
@@ -1635,7 +1635,7 @@ def __init__(self, axis=None, energy_factor=None,
By default, if dyn seq len exists, it uses it.
:param bool log_space: if True, returns in log space (i.e. uses log_softmax)
"""
from returnn.tf.util.basic import where_bc
from returnn.tf.util.basic import where_bc, set_padding_info
super(SoftmaxOverSpatialLayer, self).__init__(**kwargs)
self.start = start
self.window_start = window_start
@@ -1665,6 +1665,8 @@ def __init__(self, axis=None, energy_factor=None,
self.output_before_activation = OutputWithActivation(
energy, act_func=tf.nn.log_softmax if log_space else tf.nn.softmax) # (...,T)
self.output.placeholder = self.output_before_activation.y
if use_time_mask:
set_padding_info(self.output.placeholder, dim=self.output.dim_tags[axis], pad_value=0.)
# Never allow inf in output, as softmax should remove all -inf values used for masking
self.allow_inf_in_output = False
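
The new set_padding_info call records what a downstream consumer can rely on: with use_time_mask, the energy at padded frames is -inf going into the softmax, so the softmax output is exactly 0 there, and pad_value=0. marks the output as already masked (e.g. for DotLayer below). A minimal standalone illustration of that property in plain TensorFlow (not RETURNN API; the values are made up):

import tensorflow as tf

energy = tf.constant([[1., 2., 3., 4.]])                # (B=1, T=4)
time_mask = tf.constant([[True, True, False, False]])   # seq len 2
masked_energy = tf.where(time_mask, energy, tf.fill(tf.shape(energy), float("-inf")))
probs = tf.nn.softmax(masked_energy, axis=-1)           # padded entries are exactly 0.0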

@@ -1753,6 +1755,7 @@ def __init__(self, mask_value, axis="T", seq_len_source=None,
:param LayerBase|None window_start: Tensor of shape (B,) indicating the window start
:param LayerBase|int|None window_size:
"""
from ..util.basic import set_padding_info
super(SeqLenMaskLayer, self).__init__(**kwargs)
self.seq_len_source = seq_len_source
self.start = start
@@ -1768,6 +1771,8 @@ def __init__(self, mask_value, axis="T", seq_len_source=None,
window_size=window_size.output if isinstance(window_size, LayerBase) else window_size)
from returnn.tf.util.basic import where_bc
x_ = where_bc(mask, self.input_data.placeholder, mask_value)
axis_ = self.input_data.get_axis_from_description(axis)
set_padding_info(x_, dim=self.input_data.dim_tags[axis_], pad_value=mask_value)
self.output.placeholder = x_
if mask_value in [float("-inf"), float("inf")]:
self.allow_inf_in_output = True
@@ -5370,7 +5375,7 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
:param bool add_var2_if_empty: if var2=None, add dim=1 at the end
:param bool debug: will print debug shapes, etc.
"""
from returnn.tf.util.basic import prod
from returnn.tf.util.basic import prod, get_shape, get_padding_info_dict_ref, mask_dyn_seq_len_nd
super(DotLayer, self).__init__(**kwargs)
a_out = self.sources[0].output.copy()
b_out = self.sources[1].output.copy()
@@ -5403,27 +5408,48 @@ def __init__(self, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, d
# So we reshape such that we collapse all reduce-axes and var-axes into each a single axis.
a = a_out.placeholder
b = b_out.placeholder
a_shape = tf.shape(a)
b_shape = tf.shape(b)
a_shape = [a_out.batch_shape[i] or a_shape[i] for i in range(a_out.batch_ndim)]
b_shape = [b_out.batch_shape[i] or b_shape[i] for i in range(b_out.batch_ndim)]
a_shape = get_shape(a)
b_shape = get_shape(b)
a_rem_dims = [a_shape[i] for i in a_rem_axes]
b_rem_dims = [b_shape[i] for i in b_rem_axes]
assert len(a_rem_axes) == len(b_rem_axes), "%s: remaining shared (batch) axes do not match. sources %r" % (
self, self.sources)
assert all([
isinstance(d1, tf.Tensor) or isinstance(d2, tf.Tensor) or d1 == d2
for (d1, d2) in zip(a_rem_dims, b_rem_dims)])
a_out.dim_tags[i1] == b_out.dim_tags[i2] or d1 == d2
for (d1, d2, i1, i2) in zip(a_rem_dims, b_rem_dims, a_rem_axes, b_rem_axes)])
a_var_dims = [a_shape[i] for i in a_var_axes]
b_var_dims = [b_shape[i] for i in b_var_axes]
a_reduce_dims = [a_shape[i] for i in a_reduce_axes]
b_reduce_dims = [b_shape[i] for i in b_reduce_axes]
assert len(a_reduce_axes) == len(b_reduce_axes)
assert all([
isinstance(d1, tf.Tensor) or isinstance(d2, tf.Tensor) or d1 == d2
for (d1, d2) in zip(a_reduce_dims, b_reduce_dims)])
a_out.dim_tags[i1] == b_out.dim_tags[i2] or d1 == d2
for (d1, d2, i1, i2) in zip(a_reduce_dims, b_reduce_dims, a_reduce_axes, b_reduce_axes)])
a_var_dim = prod(a_var_dims)
b_var_dim = prod(b_var_dims)
a_reduce_dyn_axes = [i for i in a_reduce_axes if a_out.batch_shape[i] is None]
b_reduce_dyn_axes = [i for i in b_reduce_axes if b_out.batch_shape[i] is None]
assert len(a_reduce_dyn_axes) == len(b_reduce_dyn_axes)
if a_reduce_dyn_axes:
a_pad, b_pad = get_padding_info_dict_ref(a), get_padding_info_dict_ref(b)
a_pad_values = [a_pad.get(a_out.dim_tags[i], None) for i in a_reduce_dyn_axes]
b_pad_values = [b_pad.get(b_out.dim_tags[i], None) for i in b_reduce_dyn_axes]
if set(a_pad_values) == {0}:
self._info_reduce_mask = "source-0-already-masked" # it's already masked as needed
elif set(b_pad_values) == {0}:
self._info_reduce_mask = "source-1-already-masked" # it's already masked as needed
else:
# We need to apply a mask.
# We don't need it on both a and b; applying it to either one is enough.
# Use a very simple heuristic to pick the one where masking is likely cheaper.
if len(a_shape) < len(b_shape):
a = mask_dyn_seq_len_nd(a_out, pad_value=0, axes=a_reduce_dyn_axes)
self._info_reduce_mask = "mask-source-0"
else:
b = mask_dyn_seq_len_nd(b_out, pad_value=0, axes=b_reduce_dyn_axes)
self._info_reduce_mask = "mask-source-1"
else:
self._info_reduce_mask = "none-dynamic"
a_reduce_dim = prod(a_reduce_dims)
b_reduce_dim = prod(b_reduce_dims)
if debug:
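
The branch above reduces to a small decision: if either source is already known to be zero-padded along the reduced dynamic axes, no extra masking is needed; otherwise exactly one source is masked, chosen by a simple dim-count heuristic. A hypothetical standalone rendering of that logic (illustrative names, not the RETURNN API):

def choose_reduce_mask(a_pad_values, b_pad_values, a_ndim, b_ndim):
  """Which source (if any) needs masking before the dot product over dynamic axes."""
  if set(a_pad_values) == {0}:
    return "source-0-already-masked"
  if set(b_pad_values) == {0}:
    return "source-1-already-masked"
  # Neither source is known to be zero-padded: mask the (heuristically) cheaper one.
  return "mask-source-0" if a_ndim < b_ndim else "mask-source-1"

assert choose_reduce_mask([0.], [None], a_ndim=3, b_ndim=3) == "source-0-already-masked"
assert choose_reduce_mask([None], [None], a_ndim=3, b_ndim=3) == "mask-source-1"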
65 changes: 65 additions & 0 deletions returnn/tf/util/basic.py
@@ -211,6 +211,71 @@ def copy_with_new_split_axes(old_axis_splits, new_axis_splits, old_values, new_v
return new_values


def get_padding_info_dict_ref(x):
"""
:param tf.Tensor x:
:rtype: dict[DimensionTag,float|int]
"""
_attr = "RETURNN_attr_padding_value_info"
if hasattr(x, _attr):
d = getattr(x, _attr)
assert isinstance(d, dict)
return d
d = {}
setattr(x, _attr, d)
return d


def set_padding_info(x, dim, pad_value):
"""
Stores the information about which padding value to expect after masking along the given dynamic dim.

:param tf.Tensor x:
:param returnn.tf.util.data.DimensionTag dim: dynamic seq len axis
:param float|int pad_value:
"""
d = get_padding_info_dict_ref(x)
# If there is some earlier padding info, only keep it when it is the same value.
# Otherwise it becomes invalid.
for k, v in list(d.items()):
if v != pad_value:
del d[k]
d[dim] = pad_value
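
A hedged usage sketch of the two helpers above, assuming graph-mode tensors as in RETURNN (eager tensors do not accept arbitrary attributes); the DimensionTag instances here are created only for illustration:

import tensorflow as tf
from returnn.tf.util.basic import set_padding_info, get_padding_info_dict_ref
from returnn.tf.util.data import DimensionTag

time_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
other_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="other time")
with tf.compat.v1.Graph().as_default():
  x = tf.compat.v1.placeholder(tf.float32, shape=(None, None, 5))
  set_padding_info(x, dim=time_tag, pad_value=0.)
  assert get_padding_info_dict_ref(x) == {time_tag: 0.}
  # A later call with a different pad value drops the now-invalid earlier entry:
  set_padding_info(x, dim=other_tag, pad_value=-1.)
  assert get_padding_info_dict_ref(x) == {other_tag: -1.}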


def mask_dyn_seq_len_nd(x, pad_value, axes):
"""
:param Data x:
:param float|int pad_value:
:param list[int]|tuple[int] axes:
:return: masked x
:rtype: tf.Tensor
"""
x_ = x.placeholder
dim_tags = [x.dim_tags[i] for i in axes]
d = get_padding_info_dict_ref(x_)
existing_pad_values = [d.get(tag) for tag in dim_tags]
if set(existing_pad_values) == {pad_value}:
return x.placeholder # nothing to do

x_shape = get_shape(x_)
mask = tf.ones([1] * len(x_shape), dtype=tf.bool)
for axis in axes:
tag = x.dim_tags[axis]
idx_range = tf.range(x_shape[axis])
idx_range = tf.reshape(idx_range, [1] * axis + x_shape[axis:axis + 1] + [1] * (len(x_shape) - axis - 1))
assert tag.dyn_size_ext
assert set(tag.dyn_size_ext.dim_tags).issubset(x.dim_tags)
size_ext = tag.dyn_size_ext.copy_compatible_to(x, check_dtype=False)
mask_ = tf.less(idx_range, size_ext.placeholder)
mask = tf.logical_and(mask, mask_)
x_ = where_bc(mask, x_, tf.cast(tf.constant(pad_value, name="pad_value"), dtype=x_.dtype))
d = get_padding_info_dict_ref(x_)
d.clear()
d.update({k: pad_value for k in dim_tags})
return x_
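
For the common case of a single dynamic axis in a batch-major (B, T, F) tensor, the mask built by the loop above is equivalent to this plain-TF sketch (shapes and sequence lengths are illustrative):

import tensorflow as tf

x = tf.random.uniform([3, 7, 5])                      # (B, T, F)
seq_lens = tf.constant([7, 4, 2])                     # dynamic sizes along T
idx = tf.reshape(tf.range(7), [1, 7, 1])              # positions along T, broadcastable
mask = tf.less(idx, tf.reshape(seq_lens, [3, 1, 1]))  # True inside each sequence
x_masked = tf.where(mask, x, tf.zeros_like(x))        # here pad_value = 0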


class OutputWithActivation(object):
"""
Stores some tensor before and after some activation function,
61 changes: 45 additions & 16 deletions returnn/tf/util/data.py
@@ -1306,6 +1306,8 @@ def __init__(self, name,
if time_dim_axis is NotSpecified:
time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
dim_tags = tuple(dim_tags)
if auto_create_placeholders:
_auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags)
del shape_
del batch_dim_axis_
else:
@@ -3663,6 +3665,19 @@ def _batch_shape_from_shape(shape, batch_dim_axis):
return shape


def _create_size_placeholder(name, axis_wo_b, tag):
"""
:param str name:
:param int axis_wo_b:
:param DimensionTag tag:
"""
from .basic import reuse_name_scope
with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True):
dyn_size = tf_compat.v1.placeholder(
name="%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, shape=(None,))
tag.set_tag_on_size_tensor(dyn_size)


def _infer_dim_tags_tuple_from_shape(
shape,
batch_dim_axis, time_dim_axis, feature_dim_axis,
@@ -3685,7 +3700,6 @@ def _infer_dim_tags_tuple_from_shape(
:return: dim tags tuple
:rtype: tuple[DimensionTag]
"""
from .basic import reuse_name_scope
assert isinstance(shape, (tuple, list))
shape = tuple(shape)
batch_shape = _batch_shape_from_shape(shape, batch_dim_axis=batch_dim_axis)
@@ -3723,21 +3737,19 @@ def _infer_dim_tags_tuple_from_shape(
dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None
dim = batch_shape[axis]
if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis:
with reuse_name_scope("extern_data/placeholders/%s" % name, absolute=True):
dyn_size = tf_compat.v1.placeholder(
name="%s_dim%i_size" % (name, axis_wo_b), dtype=Data.size_dtype, shape=(None,))
if not tag:
if axis == time_dim_axis:
tag_name = "time"
else:
tag_name = "spatial%i" % axis
tag = DimensionTag(
description="%s:var:extern_data:%s" % (tag_name, name),
# Spatial dim tag, even if axis == feature_dim_axis. This is to keep the old behavior.
# This is such that DimensionTag.is_equal behaves as before, e.g. in Data.get_common_data.
kind=DimensionTag.Types.Spatial)
dim_tags[axis] = tag
tag.set_tag_on_size_tensor(dyn_size)
if not tag:
if axis == time_dim_axis:
tag_name = "time"
else:
tag_name = "spatial%i" % axis
tag = DimensionTag(
description="%s:var:extern_data:%s" % (tag_name, name),
# Spatial dim tag, even if axis == feature_dim_axis. This is to keep the old behavior.
# This is such that DimensionTag.is_equal behaves as before, e.g. in Data.get_common_data.
kind=DimensionTag.Types.Spatial)
dim_tags[axis] = tag
_create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag)
dyn_size = tag.dyn_size
if tag:
# Just some sanity checks.
assert isinstance(tag, DimensionTag)
@@ -3765,6 +3777,23 @@ def _infer_dim_tags_tuple_from_shape(
return tuple(dim_tags[axis] for axis in range(len(batch_shape)))


def _auto_create_size_placeholders_on_dim_tags(name, dim_tags):
"""
:param str name:
:param tuple[DimensionTag] dim_tags:
"""
batch_dim_axis = _batch_dim_axis_from_dim_tags_tuple(dim_tags)
for axis, tag in enumerate(dim_tags):
if tag.is_batch_dim():
continue
if tag.dimension is not None:
continue
if tag.dyn_size is not None:
continue
axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis)
_create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag)
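
A hedged sketch of the resulting behavior, assuming graph mode as RETURNN uses it: for a Data with dim tags (batch, time, feat), only the dynamic time tag has no static dimension and no dyn_size yet, so a single size placeholder of shape (B,) is created under extern_data/placeholders/<name> and attached to that tag:

import tensorflow as tf
from returnn.tf.util.data import Data, DimensionTag

with tf.compat.v1.Graph().as_default():
  batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch")
  time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
  feat = DimensionTag(kind=DimensionTag.Types.Feature, description="feat", dimension=5)
  data = Data(name="src", dim_tags=[batch, time, feat], auto_create_placeholders=True)
  assert time.dyn_size is not None  # size placeholder for the time axis was created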


def _get_axis_wo_b(axis_wb, batch_dim_axis, batch_ndim=None):
"""
:param int axis_wb: counted with batch-dim
68 changes: 66 additions & 2 deletions tests/test_TFNetworkLayer.py
@@ -3172,6 +3172,7 @@ def test_DotLayer():
assert_equal(seq_lens.tolist(), a_seq_lens)
assert_equal(out.shape, (B, H, max(a_seq_lens), 1))


def test_DotLayer2():
""" Test if DotLayer can handle inputs which dont have a batch-dim
"""
@@ -3193,8 +3194,8 @@ def test_DotLayer2():
b = InternalLayer(name='B',
network=net,
output=Data(name="B", shape=(S1, S2, R, V), batch_dim_axis=None, time_dim_axis=None))
assert b.output.batch_dim_axis == None
assert b.output.time_dim_axis == None
assert b.output.batch_dim_axis is None
assert b.output.time_dim_axis is None
assert b.output.shape == (S1, S2, R, V)
assert b.output.dim == V
b.output.placeholder = tf.reshape(tf.range(S1 * S2 * R * V, dtype=tf.float32), (S1, S2, R, V))
@@ -3215,6 +3216,69 @@ def test_DotLayer2():
assert_equal(out.shape, (S1, S2, B, V))


def test_DotLayer_mask_dyn_seq():
batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch")
time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
feat1 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 1", dimension=3)
feat2 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 2", dimension=5)
config = Config({
"extern_data": {
"src1": {"dim_tags": [batch, time, feat1]},
"src2": {"dim_tags": [batch, time, feat2]},
},
"network": {
"dot": {
"class": "dot", "from": ["data:src1", "data:src2"], "is_output_layer": True,
"red1": time, "red2": time, "var1": feat1, "var2": feat2
},
},
"debug_print_layer_output_template": True,
})

with make_scope() as session:
net = TFNetwork(config=config)
net.construct_from_dict(config.typed_dict["network"])
layer = net.layers["dot"]
assert isinstance(layer, DotLayer)
assert layer.output.dim_tags == (batch, feat1, feat2)
assert layer._info_reduce_mask == "mask-source-1"

feed_dict = make_feed_dict(net.extern_data)
session.run(layer.output.placeholder, feed_dict=feed_dict)


def test_DotLayer_mask_dyn_seq_after_softmax():
batch = DimensionTag(kind=DimensionTag.Types.Batch, description="batch")
time = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
feat1 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 1", dimension=3)
feat2 = DimensionTag(kind=DimensionTag.Types.Feature, description="feature 2", dimension=5)
config = Config({
"extern_data": {
"src1": {"dim_tags": [batch, time, feat1]},
"src2": {"dim_tags": [batch, time, feat2]},
},
"network": {
"sm1": {"class": "softmax_over_spatial", "from": "data:src1"},
"dot": {
"class": "dot", "from": ["sm1", "data:src2"], "is_output_layer": True,
"red1": time, "red2": time, "var1": feat1, "var2": feat2
},
},
"debug_print_layer_output_template": True,
})

with make_scope() as session:
net = TFNetwork(config=config)
net.construct_from_dict(config.typed_dict["network"])
layer = net.layers["dot"]
assert isinstance(layer, DotLayer)
assert layer.output.dim_tags == (batch, feat1, feat2)
assert layer._info_reduce_mask == "source-0-already-masked"

feed_dict = make_feed_dict(net.extern_data)
session.run(layer.output.placeholder, feed_dict=feed_dict)
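
The two new tests can be run on their own, e.g. (assuming a pytest-compatible environment):

python -m pytest tests/test_TFNetworkLayer.py -k "DotLayer_mask_dyn_seq"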


def test_subnet_load_on_init():
import tempfile
model_tmp_dir = tempfile.mkdtemp("tmp-checkpoint")