CumConcatLayer #589

Merged: 2 commits, Sep 12, 2021
170 changes: 170 additions & 0 deletions returnn/tf/layers/rec.py
@@ -8531,3 +8531,173 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
      kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
    data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames of a time-axis.
  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.

  This layer can be used as a base for auto-regressive self-attention.

  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input
  to the previously accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by its own separate dim tag.
  For example, in the first frame,
  the output will be of shape `[1] + input_shape`,
  in the second frame of shape `[2] + input_shape`,
  and so on,
  and in the last frame of shape `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but replaces the time dim tag with the dim tag `new_dim`,
  converted for outside the loop.

  Normally this optimization should not matter for the user,
  i.e. for the user, the logical behavior is always as if the layer were inside the rec loop.
  Outside the loop,
  the output represents a tensor of shape `[T, new_dim] + input_shape`,
  although we actually have a different `new_dim` outside the loop,
  and the `T` dim is not actually there;
  no information is lost, though,
  because the last frame contains all accumulated frames.
  This `new_dim` outside the loop stores the dynamic seq lengths
  per frame of the loop, i.e. the dyn seq lens are extended to shape [B,T] or [T]
  (instead of the usual [B]).
  This way, following layers use different seq lengths of `new_dim` for different loop frames,
  just as if the `T` dim actually existed.
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%r must be used inside a RecLayer" % self
    out_axis = self.output.get_axis_from_description(new_dim)
    new_dim_ = self.output.dim_tags[out_axis]
    assert new_dim_.control_flow_ctx == self.output.control_flow_ctx == self.network.get_control_flow_ctx()

    if not self.input_data.has_axis(rec_layer.time_dim_tag):  # inside loop
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames
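      # Note: the initial "state" (see get_rec_initial_extra_outputs below) has size 0 along new_dim,
      # so after loop frame i the accumulated size along new_dim is i + 1.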

      if not new_dim_.dyn_size_ext:
        # Unbroadcasting to [B] is not needed because any layers operating on this
        # should be able to handle extended dyn sizes.
        # Clipping it to the max length for sequences in the loop which have already ended
        # (i.e. considering the end flag)
        # is also not needed because any calculations after the end are irrelevant.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = self.network.get_rec_step_index() + 1  # scalar
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-inside" % self.name,
          dim_tags=[],  # scalar
          placeholder=dyn_size, dtype="int32",
          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())

    else:  # outside loop
      # If not inside a rec loop, this layer is a no-op on the tensor.
      self.output.placeholder = self.input_data.placeholder

      # However, we used new dim tags, which were already prepared.
      # We now must fill in the extended dynamic size information.
      if not new_dim_.dyn_size_ext:
        # This must match the logic above for inside the loop.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[rec_layer.time_dim_tag],
          placeholder=dyn_size, dtype="int32",
          batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())
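        # E.g. for max(T) == 4, dyn_size is [1, 2, 3, 4], i.e. loop frame i gets accumulated length i + 1,
        # consistent with get_rec_step_index() + 1 inside the loop.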

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
    rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
    assert rec_time_dim
    ctx = network.get_control_flow_ctx()
    assert ctx == input_data.control_flow_ctx
    new_dim_in_ctx = new_dim.get_for_batch_ctx(batch=input_data.batch, ctx=ctx)

    if not input_data.has_axis(rec_time_dim):  # inside loop
      assert ctx and ctx.is_loop() and ctx.loop_spatial_dim == rec_time_dim
      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we copy the input as batch-major here, and then add the new dim at axis 1.
      # In the future, once SelectSearchSourcesLayer supports this, we can change this to operate on axis 0,
      # which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim_in_ctx, unbroadcast=True, axis=1)
      return out

    else:  # outside loop
      # Assume that the input has the time dim from the rec layer.
      axis = input_data.get_axis_from_description(rec_time_dim)
      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_in_ctx)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)  # the accumulated state starts empty along new_dim
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}
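For intuition, here is a minimal standalone sketch of the behavior described in the docstring, in plain TensorFlow rather than RETURNN (the sizes and names below are made up for illustration): the accumulated state grows by one frame along `new_dim` in every loop step, and the per-frame lengths are `1, 2, ..., T`, matching both the inside-loop `get_rec_step_index() + 1` and the outside-loop `tf.range(...) + 1` logic above.

```python
# Sketch only: emulates the inside-loop accumulation of CumConcatLayer with plain TF.
import tensorflow as tf

batch, feat, num_frames = 3, 4, 5  # made-up sizes
frames = [tf.random.normal([batch, 1, feat]) for _ in range(num_frames)]  # one [B,1,D] input per loop frame

state = tf.zeros([batch, 0, feat])  # initial state: empty along new_dim (cf. get_rec_initial_extra_outputs)
accumulated_lengths = []
for i, frame in enumerate(frames):
  state = tf.concat([state, frame], axis=1)  # [B, i+1, D], like rec_vars_outputs["state"]
  accumulated_lengths.append(i + 1)  # per-frame dyn size, like get_rec_step_index() + 1

assert state.shape == (batch, num_frames, feat)
assert accumulated_lengths == [1, 2, 3, 4, 5]  # matches tf.range(max(T)) + 1 outside the loop
```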
37 changes: 35 additions & 2 deletions tests/test_TFNetworkRecLayer.py
@@ -3385,8 +3385,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
    rec_layer_dict["unit"].update(other_subnet_layers)
  config = Config({
    "debug_print_layer_output_template": True,
-   "num_inputs": n_in,
-   "num_outputs": n_out
+   "extern_data": {"data": {"dim": n_in}},
  })
  from returnn.tf.layers.rec import _SubnetworkRecCell
  with make_scope() as session:
@@ -3463,6 +3462,40 @@ def test_reclayer_optimize_out_selfatt_left():
"class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})


def test_reclayer_optimize_out_cum_concat_gen_self_att():
  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
  n_key = 5
  n_value = 7
  check_reclayer_optimize_out(
    {"class": "linear", "from": "att", "activation": None},
    {
      # This is very much the vanilla self-attention,
      # implemented via the new generic way.
      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
      # Commented shapes are always for the layers inside the loop (not optimized).
      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
      "q": {"class": "copy", "from": "qkv_split/0"},  # inside [B,K]. optimized out [T,B,K]
      "k": {"class": "copy", "from": "qkv_split/1"},  # inside [B,K]. optimized out [T,B,K]
      "v": {"class": "copy", "from": "qkv_split/2"},  # inside [B,V]. optimized out [T,B,V]
      # cum_concat here. Note that the optimized-out shape is not [T,max(t),B,K] as you might expect,
      # but instead uses the optimized format with an extended dyn size on the special dim tag,
      # i.e. [t*,B,K], representing [T,t*,B,K].
      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # inside [t,B,K]. opt out [t*,B,K]
      "v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"},  # inside [t,B,V]. opt out [t*,B,V]
      "energy": {
        "class": "dot", "from": ["q", "k_accum"],
        "red1": "static:-1", "red2": "static:-1",
        "var1": None, "var2": new_dim},  # inside [B,t]. optimized out [T,B,t*]
      "att_weights": {
        "class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # inside [B,t]. opt out [T,B,t*]
      "att": {
        "class": "dot", "from": ["att_weights", "v_accum"],
        "red1": new_dim, "red2": new_dim,
        "var1": None, "var2": "static:-1"},  # inside [B,V]. opt out [T,B,V]
    })
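For reference, a hedged sketch of the single-head attention math that this config expresses for one loop frame, again in plain TensorFlow rather than RETURNN (batch-major layout and all sizes are assumptions for illustration):

```python
# Sketch only: the per-frame math behind the dot/softmax_over_spatial layers above.
import tensorflow as tf

B, K, V, t = 3, 5, 7, 4  # made-up sizes; t = number of accumulated frames at this loop step
q = tf.random.normal([B, K])  # "q" inside the loop
k_accum = tf.random.normal([B, t, K])  # "k_accum" (cum_concat over "k")
v_accum = tf.random.normal([B, t, V])  # "v_accum" (cum_concat over "v")

energy = tf.einsum("bk,btk->bt", q, k_accum)  # "energy": reduce the key dim -> [B,t]
att_weights = tf.nn.softmax(energy, axis=-1)  # "att_weights": softmax over the accumulated frames
att = tf.einsum("bt,btv->bv", att_weights, v_accum)  # "att": reduce new_dim -> [B,V]
assert att.shape == (B, V)
```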


def test_reclayer_optimize_out_dot():
  # Used for multi-head dot-attention.
  AttNumHeads = 4