@@ -3385,8 +3385,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
     rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
@@ -3463,6 +3462,40 @@ def test_reclayer_optimize_out_selfatt_left():
     "class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})


+def test_reclayer_optimize_out_cum_concat_gen_self_att():
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
+  n_key = 5
+  n_value = 7
+  check_reclayer_optimize_out(
+    {"class": "linear", "from": "att", "activation": None},
+    {
+      # This is very much the vanilla self attention,
+      # implemented via the new generic way.
+      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
+      # Commented shapes are always for the layers inside the loop (not optimized).
+      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
+      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
+      "q": {"class": "copy", "from": "qkv_split/0"},  # inside [B,K]. optimized out [T,B,K]
+      "k": {"class": "copy", "from": "qkv_split/1"},  # inside [B,K]. optimized out [T,B,K]
+      "v": {"class": "copy", "from": "qkv_split/2"},  # inside [B,V]. optimized out [T,B,V]
+      # cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
+      # but instead using the optimized format, with extended dyn size on the special dim tag,
+      # i.e. [t*,B,K], representing [T,t*,B,K].
+      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # inside [t,B,K]. opt out [t*,B,K]
+ "v_accum" : {"class" : "cum_concat" , "new_dim" : new_dim , "from" : "v" }, # inside [t,B,V]. opt out [t*,B,K]
+      "energy": {
+        "class": "dot", "from": ["q", "k_accum"],
+        "red1": "static:-1", "red2": "static:-1",
+        "var1": None, "var2": new_dim},  # inside [B,t]. optimized out [T,B,t*]
+      "att_weights": {
+        "class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # inside [B,t]. opt out [T,B,t*]
+      "att": {
+        "class": "dot", "from": ["att_weights", "v_accum"],
+        "red1": new_dim, "red2": new_dim,
+        "var1": None, "var2": "static:-1"},  # inside [B,V]. opt out [T,B,V]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
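
For orientation (not part of the commit): a minimal NumPy sketch of the single-step computation that the cum_concat-based layer graph in the new test expresses, namely the current query attending over all keys and values accumulated up to the current step. The function and argument names below are illustrative only and simply mirror the layer names in the test; they are not RETURNN API.

# Minimal sketch of the per-step attention that the test's layer graph computes.
# Names (step_self_attention, q, k_accum, v_accum) are illustrative, not RETURNN API.
import numpy

def step_self_attention(q, k_accum, v_accum):
  """
  :param numpy.ndarray q: [B,K], query of the current step ("q" layer, inside the loop)
  :param numpy.ndarray k_accum: [t,B,K], keys accumulated over steps 1..t ("k_accum")
  :param numpy.ndarray v_accum: [t,B,V], values accumulated over steps 1..t ("v_accum")
  :return: attention output of shape [B,V] ("att")
  :rtype: numpy.ndarray
  """
  energy = numpy.einsum("bk,tbk->bt", q, k_accum)  # "energy": dot over the key dim -> [B,t]
  energy = energy - energy.max(axis=1, keepdims=True)  # stabilize the softmax
  att_weights = numpy.exp(energy)
  att_weights /= att_weights.sum(axis=1, keepdims=True)  # "att_weights": softmax over t
  return numpy.einsum("bt,tbv->bv", att_weights, v_accum)  # "att": weighted sum over t -> [B,V]

# E.g. with B=5 (batch), K=5, V=7 and t=3 accumulated steps:
# step_self_attention(numpy.ones((5, 5)), numpy.ones((3, 5, 5)), numpy.ones((3, 5, 7))).shape == (5, 7)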