Skip to content

Commit 4661bbe

Browse files
albertzZettelkasten
andcommitted
CumConcatLayer (#589)
This is for generalized self attention (#391). Fixes #391. Co-authored-by: Frithjof <[email protected]>
1 parent 9baf23d commit 4661bbe

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

returnn/tf/layers/rec.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8531,3 +8531,173 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
85318531
kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
85328532
data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
85338533
return data
8534+
8535+
8536+
class CumConcatLayer(_ConcatInputLayer):
8537+
"""
8538+
Concatenates all previous frames of a time-axis.
8539+
Like :class:`CumsumLayer` uses `sum`, this layer uses `concat`.
8540+
8541+
This layer can be used as a base for auto-regressive self-attention.
8542+
8543+
This layer expects to be inside a :class:`RecLayer`.
8544+
8545+
Inside a rec loop (not optimized out),
8546+
this will concatenate the current input
8547+
to the previous accumulated inputs.
8548+
For an input of shape `input_shape`,
8549+
it will output a tensor of shape `[new_dim] + input_shape`.
8550+
`new_dim` is a special dimension, usually of length `i`,
8551+
where `i` is the current loop frame,
8552+
i.e. the length increases in every loop frame.
8553+
`new_dim` is specified by a separate own dim tag.
8554+
For example, in the first frame,
8555+
this will be of shape `[1] + input_shape`,
8556+
in the second frame shape `[2] + input_shape`,
8557+
and so on,
8558+
and in the last frame shape `[T] + input_shape`.
8559+
8560+
Outside the rec loop (optimized out),
8561+
this layer expects an input with the time dim of the rec layer,
8562+
and returns the input as-is,
8563+
but replacing the time dim tag with the dim tag `new_dim`
8564+
converted as outside the loop.
8565+
8566+
Normally the optimization should not matter for the user,
8567+
i.e. for the user, the logical behavior is always as being inside the rec loop.
8568+
Outside the loop,
8569+
the output represents a tensor of shape `[T, new_dim] + input_shape`,
8570+
although we actually have another `new_dim` outside the loop,
8571+
and `T` is not actually there,
8572+
but we still have all the information,
8573+
because the last frame has all information.
8574+
This `new_dim` outside the loop stores all the dynamic seq lengths
8575+
per frame of the loop, i.e. the dyn seq len are extended of shape [B,T] or [T]
8576+
(unlike usually just [B]).
8577+
This way following layers use different seq lengths of `new_dim` for different loop frames,
8578+
just like if the `T` dim would actually exist.
8579+
"""
8580+
layer_class = "cum_concat"
8581+
recurrent = True # order matters
8582+
8583+
def __init__(self, new_dim, **kwargs):
8584+
"""
8585+
:param DimensionTag new_dim:
8586+
"""
8587+
super(CumConcatLayer, self).__init__(**kwargs)
8588+
rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
8589+
assert rec_layer, "%r must be used inside a RecLayer" % self
8590+
out_axis = self.output.get_axis_from_description(new_dim)
8591+
new_dim_ = self.output.dim_tags[out_axis]
8592+
assert new_dim_.control_flow_ctx == self.output.control_flow_ctx == self.network.get_control_flow_ctx()
8593+
8594+
if not self.input_data.has_axis(rec_layer.time_dim_tag): # inside loop
8595+
current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
8596+
current_frame = current_data.placeholder # [B, 1, ..., D]
8597+
last_frames = self._rec_previous_layer.rec_vars_outputs["state"] # [B, t, ..., D]
8598+
concat_frames = tf.concat([last_frames, current_frame], axis=out_axis) # [B, t+1, ..., D]
8599+
self.rec_vars_outputs["state"] = concat_frames
8600+
self.output.placeholder = concat_frames
8601+
8602+
if not new_dim_.dyn_size_ext:
8603+
# Unbroadcasting to [B] is not needed because any layers operating on this
8604+
# should be able to handle extended dyn sizes.
8605+
# Clipping it to the max length for sequences in the loop which are already ended
8606+
# (i.e. considering the end flag)
8607+
# is also not needed because any calculations after the end are irrelevant.
8608+
# Note: In case we have some initial state/output, this can be extended.
8609+
dyn_size = self.network.get_rec_step_index() + 1 # scalar
8610+
new_dim_.dyn_size_ext = Data(
8611+
name="%s:cum-concat:size-inside" % self.name,
8612+
dim_tags=[], # scalar
8613+
placeholder=dyn_size, dtype="int32",
8614+
batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())
8615+
8616+
else: # outside loop
8617+
# If not inside a rec loop, this layer is a no-op on the tensor.
8618+
self.output.placeholder = self.input_data.placeholder
8619+
8620+
# However, we used new dim tags, which were already prepared.
8621+
# We now must fill in the extended dynamic size information.
8622+
if not new_dim_.dyn_size_ext:
8623+
# This must match the logic above for inside the loop.
8624+
# Note: In case we have some initial state/output, this can be extended.
8625+
dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1 # [T]
8626+
new_dim_.dyn_size_ext = Data(
8627+
name="%s:cum-concat:size-outside" % self.name,
8628+
dim_tags=[rec_layer.time_dim_tag],
8629+
placeholder=dyn_size, dtype="int32",
8630+
batch=self.output.batch, control_flow_ctx=self.network.get_control_flow_ctx())
8631+
8632+
@classmethod
8633+
def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
8634+
"""
8635+
:param str name:
8636+
:param returnn.tf.network.TFNetwork network:
8637+
:param list[LayerBase] sources:
8638+
:param DimensionTag new_dim:
8639+
:rtype: Data
8640+
"""
8641+
input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
8642+
assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
8643+
rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
8644+
assert rec_time_dim
8645+
ctx = network.get_control_flow_ctx()
8646+
assert ctx == input_data.control_flow_ctx
8647+
new_dim_in_ctx = new_dim.get_for_batch_ctx(batch=input_data.batch, ctx=ctx)
8648+
8649+
if not input_data.has_axis(rec_time_dim): # inside loop
8650+
assert ctx and ctx.is_loop() and ctx.loop_spatial_dim == rec_time_dim
8651+
# Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
8652+
# Therefore we here copy the input as batch-major, and then add the time axis at axis 1.
8653+
# In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
8654+
# which should be more efficient
8655+
out = input_data.copy_as_batch_major()
8656+
out = out.copy_add_dim_by_tag(new_dim_in_ctx, unbroadcast=True, axis=1)
8657+
return out
8658+
8659+
else: # outside loop
8660+
# Assume that the input has the time dim from the rec layer.
8661+
axis = input_data.get_axis_from_description(rec_time_dim)
8662+
return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_in_ctx)
8663+
8664+
# noinspection PyMethodOverriding
8665+
@classmethod
8666+
def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
8667+
"""
8668+
:param returnn.tf.network.TFNetwork network:
8669+
:param tf.Tensor batch_dim:
8670+
:param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
8671+
:param list[LayerBase] sources:
8672+
:param Data output:
8673+
:param DimensionTag new_dim:
8674+
:rtype: dict[str,tf.Tensor]
8675+
"""
8676+
if network.is_inside_rec_layer():
8677+
shape = []
8678+
for tag in output.dim_tags:
8679+
if tag.is_batch_dim():
8680+
shape.append(batch_dim)
8681+
elif tag == new_dim:
8682+
shape.append(0)
8683+
elif tag.dimension is not None:
8684+
shape.append(tag.dimension)
8685+
else:
8686+
assert tag.dyn_size is not None
8687+
shape.append(tf.math.reduce_max(tag.dyn_size))
8688+
return {"state": tf.zeros(shape, dtype=output.dtype)}
8689+
else:
8690+
return {}
8691+
8692+
@classmethod
8693+
def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
8694+
"""
8695+
:param returnn.tf.network.TFNetwork network:
8696+
:param list[LayerBase] sources:
8697+
:param Data output:
8698+
:rtype: dict[str, tf.TensorShape]
8699+
"""
8700+
if network.is_inside_rec_layer():
8701+
return {"state": tf.TensorShape(output.batch_shape)}
8702+
else:
8703+
return {}

0 commit comments

Comments
 (0)