
Commit 09c4cb8

albertz and Zettelkasten committed
CumConcatLayer
This is for generalized self attention (#391). Co-authored-by: Frithjof <[email protected]>
1 parent bd2771d commit 09c4cb8

1 file changed: +175 −0 lines changed


returnn/tf/layers/rec.py

Lines changed: 175 additions & 0 deletions
@@ -8498,3 +8498,178 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
        kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
      data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames of a time-axis.
  Like :class:`CumsumLayer` uses `sum`, this layer uses `concat`.

  This layer can be used as a base for auto-regressive self-attention.

  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input
  to the previous accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by a separate own dim tag.
  For example, in the first frame,
  this will be of shape `[1] + input_shape`,
  in the second frame shape `[2] + input_shape`,
  and so on,
  and in the last frame shape `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but replacing the time dim tag with the dim tag `new_dim`
  converted as outside the loop.

  Normally the optimization should not matter for the user,
  i.e. for the user, the logical behavior is always as being inside the rec loop.
  Outside the loop,
  the output represents a tensor of shape `[T, new_dim] + input_shape`,
  although we actually have another `new_dim` outside the loop,
  and `T` is not actually there,
  but we still have all the information,
  because the last frame has all the information.
  This `new_dim` outside the loop stores all the dynamic seq lengths
  per frame of the loop, i.e. the dyn seq lens are extended to shape [B,T] or [T]
  (unlike the usual [B]).
  This way, following layers use different seq lengths of `new_dim` for different loop frames,
  just as if the `T` dim actually existed.
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%r must be used inside a RecLayer" % self
    out_axis = self.output.get_axis_from_description(new_dim)
    new_dim_ = self.output.dim_tags[out_axis]

    if not self.input_data.has_axis(rec_layer.time_dim_tag):  # inside loop
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames

      if not new_dim_.dyn_size_ext:
        # Unbroadcasting to [B] is not needed because any layers operating on this
        # should be able to handle extended dyn sizes.
        # Clipping it to the max length for sequences in the loop which have already ended
        # (i.e. considering the end flag)
        # is also not needed because any calculations after the end are irrelevant.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = self.network.get_rec_step_index() + 1  # scalar
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-inside" % self.name,
          dim_tags=[],  # scalar
          placeholder=dyn_size, dtype="int32")

    else:  # outside loop
      # If not inside a rec loop, this layer is a no-op on the tensor.
      self.output.placeholder = self.input_data.placeholder

      # However, we used new dim tags, which were already prepared.
      # We now must fill in the extended dynamic size information.
      if not new_dim_.dyn_size_ext:
        # This must match the logic above for inside the loop.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[rec_layer.time_dim_tag],
          placeholder=dyn_size, dtype="int32")

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
    rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
    assert rec_time_dim
    new_dim_base = new_dim.get_same_base()
    if new_dim_base.per_spatial_frame is None:
      new_dim_base.per_spatial_frame = rec_time_dim
    else:
      assert new_dim_base.per_spatial_frame == rec_time_dim

    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    if not input_data.has_axis(rec_time_dim):  # inside loop
      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we here copy the input as batch-major, and then add the new dim at axis 1.
      # In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
      # which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim_base, unbroadcast=True, axis=1)
      return out

    else:  # outside loop
      if not new_dim_base.per_spatial_frame_accumulated:
        new_dim_accum = DimensionTag(
          kind=new_dim_base.kind, description="%s:accumulated" % name)
        new_dim_accum.declare_same_as(new_dim_base)
        new_dim_base.per_spatial_frame_accumulated = new_dim_accum
      else:
        new_dim_accum = new_dim_base.per_spatial_frame_accumulated
      # Assume that the input has the time dim from the rec layer.
      axis = input_data.get_axis_from_description(rec_time_dim)
      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_accum)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)  # the accumulated axis starts empty
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}

0 commit comments
