@@ -8513,3 +8513,178 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
8513
8513
kind = DimensionTag .Types .Spatial , description = "%s_rel_pos_enc_time" % name , dimension = None )
8514
8514
data = data .copy_template_new_dim_tags ((dummy_dim_tag , time_dim_tag , feature_dim_tag ))
8515
8515
return data
8516
+
8517
+
8518
class CumConcatLayer(_ConcatInputLayer):
  """
  Cumulative concatenation of the frames of a time axis.
  This is to `concat` what :class:`CumsumLayer` is to `sum`,
  and it is intended as a building block for auto-regressive self-attention.

  The layer must live inside a :class:`RecLayer`.

  Inside the rec loop (i.e. when not optimized out of the loop),
  the current input frame is appended to the frames accumulated so far.
  Given an input of shape `input_shape`, the output has shape `[new_dim] + input_shape`.
  `new_dim` is a dedicated dimension with its own dim tag.
  Its length is usually `i`, with `i` being the current loop frame, so it grows by one per frame:
  `[1] + input_shape` in the first frame, `[2] + input_shape` in the second frame,
  ..., `[T] + input_shape` in the last frame.

  Outside the rec loop (i.e. when optimized out),
  the input is expected to carry the time dim of the rec layer.
  The tensor is passed through unchanged,
  and only the time dim tag is replaced by the outside-loop variant of `new_dim`.

  This optimization should be transparent for the user:
  logically, the layer always behaves as if it were inside the rec loop.
  Outside the loop, the output stands for a tensor of shape `[T, new_dim] + input_shape`.
  The `T` axis is not really there, and `new_dim` is a different dim tag than inside the loop,
  but nothing is lost because the last frame already contains all the frames.
  The outside-loop `new_dim` carries the dynamic seq lengths of every loop frame,
  i.e. its dynamic sizes are extended to shape [B,T] or [T] (instead of the usual [B]).
  Following layers thereby see a different seq length of `new_dim` per loop frame,
  exactly as if the `T` axis existed.
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim: describes the accumulated axis. It is resolved to an axis of the output,
      and the extended dynamic size of the dim tag found there is filled in, unless it is already set.
    :param kwargs: passed on to the base class constructor
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    parent_rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert parent_rec_layer, "%r must be used inside a RecLayer" % self
    accum_axis = self.output.get_axis_from_description(new_dim)
    accum_tag = self.output.dim_tags[accum_axis]
    loop_time_tag = parent_rec_layer.time_dim_tag

    if self.input_data.has_axis(loop_time_tag):
      # Outside the loop: the tensor itself is passed through as-is.
      self.output.placeholder = self.input_data.placeholder

      # The new dim tags were already prepared for the output template.
      # What is left to do is to provide their extended dynamic sizes.
      if not accum_tag.dyn_size_ext:
        # Must be kept in sync with the inside-loop logic below.
        # Note: This can be extended in case we have some initial state/output.
        sizes_per_frame = tf.range(tf.math.reduce_max(loop_time_tag.dyn_size)) + 1  # [T]
        accum_tag.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[loop_time_tag],
          placeholder=sizes_per_frame, dtype="int32")
      return

    # Inside the loop: append the current frame to the accumulated state along the new axis.
    frame_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
    new_frame = frame_data.placeholder  # [B, 1, ..., D]
    prev_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
    all_frames = tf.concat([prev_frames, new_frame], axis=accum_axis)  # [B, t+1, ..., D]
    self.rec_vars_outputs["state"] = all_frames
    self.output.placeholder = all_frames

    if not accum_tag.dyn_size_ext:
      # No unbroadcasting to [B] here:
      # layers which operate on this are expected to cope with extended dyn sizes.
      # No clipping to the max length for sequences which have already ended in the loop
      # (i.e. no handling of the end flag) either:
      # whatever gets calculated after the end is irrelevant.
      # Note: This can be extended in case we have some initial state/output.
      current_size = self.network.get_rec_step_index() + 1  # scalar
      accum_tag.dyn_size_ext = Data(
        name="%s:cum-concat:size-inside" % self.name,
        dim_tags=[],  # scalar
        placeholder=current_size, dtype="int32")

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    Builds the output template.
    Inside the loop, the input is made batch-major and the new dim is inserted at axis 1.
    Outside the loop, the rec time dim of the input is replaced by the accumulated variant of the new dim.
    As a side effect, the base of `new_dim` gets linked to the rec time dim
    and, outside the loop, to its accumulated dim tag.

    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
    loop_time_tag = network.get_inside_rec_time_dim(inside_loop=False)
    assert loop_time_tag
    base_tag = new_dim.get_same_base()
    if new_dim_is_unbound := base_tag.per_spatial_frame is None:
      base_tag.per_spatial_frame = loop_time_tag
    if not new_dim_is_unbound:
      assert base_tag.per_spatial_frame == loop_time_tag

    template = get_concat_sources_data_template(sources, name="%s_output" % name)
    if not template.has_axis(loop_time_tag):
      # Inside the loop.
      # SelectSearchSourcesLayer currently assumes all rec_vars_outputs to be batch-major.
      # That is why the input is copied as batch-major here, with the time axis added at axis 1.
      # Once SelectSearchSourcesLayer supports it, this can be changed to operate on axis 0,
      # which should be more efficient.
      return template.copy_as_batch_major().copy_add_dim_by_tag(base_tag, unbroadcast=True, axis=1)

    # Outside the loop: get the accumulated dim tag, creating and registering it on first use.
    accum_tag = base_tag.per_spatial_frame_accumulated
    if not accum_tag:
      accum_tag = DimensionTag(
        kind=base_tag.kind, description="%s:accumulated" % name)
      accum_tag.declare_same_as(base_tag)
      base_tag.per_spatial_frame_accumulated = accum_tag
    # The input is assumed to have the time dim from the rec layer.
    time_axis = template.get_axis_from_description(loop_time_tag)
    return template.copy_template_replace_dim_tag(axis=time_axis, new_dim_tag=accum_tag)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    Provides the initial (empty) accumulator, i.e. zeros with length 0 along `new_dim`.
    Nothing is needed when the layer is not inside the rec loop.

    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :return: ``{"state": zeros}`` inside the rec loop, otherwise an empty dict
    :rtype: dict[str,tf.Tensor]
    """
    if not network.is_inside_rec_layer():
      return {}

    def _initial_size(tag):
      """
      :param DimensionTag tag: one of the output dim tags
      :return: size to use for that axis in the initial state
      :rtype: int|tf.Tensor
      """
      if tag.is_batch_dim():
        return batch_dim
      if tag == new_dim:
        return 0  # nothing accumulated yet
      if tag.dimension is not None:
        return tag.dimension
      assert tag.dyn_size is not None
      return tf.math.reduce_max(tag.dyn_size)

    initial_shape = [_initial_size(tag) for tag in output.dim_tags]
    return {"state": tf.zeros(initial_shape, dtype=output.dtype)}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    Shape invariant of the accumulator, taken from the batch shape of the output template.
    Nothing is needed when the layer is not inside the rec loop.

    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :return: ``{"state": shape}`` inside the rec loop, otherwise an empty dict
    :rtype: dict[str, tf.TensorShape]
    """
    if not network.is_inside_rec_layer():
      return {}
    return {"state": tf.TensorShape(output.batch_shape)}
0 commit comments