@@ -3460,6 +3460,43 @@ def test_reclayer_optimize_out_dot():
     rtol=1e-3)
 
 
+def test_reclayer_optimize_out_dot_consistent_axes():
+  # https://github.com/rwth-i6/returnn/issues/569
+  # Used for multi-head dot-attention.
+  n_heads = 4
+  n_key = 5
+  n_value = 7
+  n_key_total = n_heads * n_key
+  n_value_total = n_heads * n_value
+  check_reclayer_optimize_out(
+    {"class": "linear", "activation": None, "from": "att"},
+    other_subnet_layers={
+      "s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source",
+            "n_out": n_key_total},  # (B, D) -- Q (query). D should be same as enc_ctx
+      "att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"},  # (B, H, D/H)
+      # Here is the main test, the dot-layer:
+      "energy": {"class": "dot", "red1": -1, "red2": -1, "var1": "T", "var2": "T",
+                 "from": ["base:enc_ctx", "att_query"]},
+      # energy inside the loop will be (B, H, enc-T, 1).
+      # energy outside the loop will be (B, H, enc-T, dec-T). I.e. enc-T is still the first time axis.
+      "att_weights": {"class": "softmax_over_spatial", "from": "energy"},  # (B, enc-T, H, 1)
+      "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},  # (B, H, V)
+      "att": {"class": "merge_dims", "axes": "static", "from": "att0"},  # (B, H*V); Use "static" here.
+    },
+    shared_base_net={
+      "encoder": {"class": "copy", "from": "data"},
+      "enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
+                   "n_out": n_key_total},  # (B, enc-T, D)
+      "enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key),
+                  "from": "enc_ctx0", "is_output_layer": True},  # (B, enc-T, H, D/H)
+      "enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
+                     "n_out": n_value_total},
+      "enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value),
+                    "from": "enc_value0", "is_output_layer": True},  # (B, enc-T, H, D/H)
+    },
+    rtol=1e-3)
+
+
 def test_reclayer_optimize_out_dot_kv_in_rec():
   # Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
   AttNumHeads = 4
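
Note (not part of the commit): a minimal NumPy sketch, under assumed placeholder dimension sizes, of the shape behavior the comments in the added test describe for the dot / softmax_over_spatial / generic_attention chain, i.e. multi-head dot-attention computed per step inside the rec loop versus over the whole dec-T axis once the layers are optimized out.

```python
import numpy as np

B, H, T_enc, T_dec, d_k, d_v = 2, 4, 6, 3, 5, 7  # placeholder sizes, not taken from the test

enc_ctx = np.random.randn(B, T_enc, H, d_k)    # (B, enc-T, H, D/H), analogous to "enc_ctx"
enc_value = np.random.randn(B, T_enc, H, d_v)  # (B, enc-T, H, D/H), analogous to "enc_value"

# Inside the loop: the query has no dec-T axis, so the energy gets a dummy axis of size 1.
query_in = np.random.randn(B, H, d_k)  # (B, H, D/H), analogous to "att_query" per decoder step
energy_in = np.einsum("bthk,bhk->bht", enc_ctx, query_in)[..., None]  # (B, H, enc-T, 1)

# Outside the loop (optimized out): the query carries the dec-T axis,
# and enc-T remains the first time axis of the energy.
query_out = np.random.randn(B, T_dec, H, d_k)  # (B, dec-T, H, D/H)
energy_out = np.einsum("bthk,bshk->bhts", enc_ctx, query_out)  # (B, H, enc-T, dec-T)

# Softmax over enc-T, then a weighted sum of the values over enc-T:
weights = np.exp(energy_out)
weights /= weights.sum(axis=2, keepdims=True)           # (B, H, enc-T, dec-T)
att = np.einsum("bhts,bthv->bshv", weights, enc_value)  # (B, dec-T, H, V)
att = att.reshape(B, T_dec, H * d_v)                    # (B, dec-T, H*V), analogous to "att"
print(energy_in.shape, energy_out.shape, att.shape)
```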