@@ -3482,6 +3482,43 @@ def test_reclayer_optimize_out_dot():
3482
3482
rtol = 1e-3 )
3483
3483
3484
3484
3485
def test_reclayer_optimize_out_dot_consistent_axes():
  """
  Check that the dot-layer keeps consistent axes when the rec layer is optimized out.
  Models multi-head dot-attention. See https://github.com/rwth-i6/returnn/issues/569.
  """
  num_heads = 4
  key_dim_per_head = 5
  value_dim_per_head = 7
  total_key_dim = num_heads * key_dim_per_head
  total_value_dim = num_heads * value_dim_per_head
  check_reclayer_optimize_out(
    {"class": "linear", "activation": None, "from": "att"},
    other_subnet_layers={
      # Query projection. Its dim must match enc_ctx. Shape inside loop: (B, D).
      "s": {
        "class": "linear", "activation": None, "with_bias": False,
        "from": "data:source", "n_out": total_key_dim},
      # Split the feature axis into heads: (B, H, D/H).
      "att_query": {
        "class": "split_dims", "axis": "F", "dims": (num_heads, key_dim_per_head), "from": "s"},
      # The dot-layer is the layer under test here.
      # Inside the loop, energy is (B, H, enc-T, 1);
      # after optimization out of the loop it is (B, H, enc-T, dec-T),
      # i.e. enc-T must remain the first time axis.
      "energy": {
        "class": "dot", "red1": -1, "red2": -1, "var1": "T", "var2": "T",
        "from": ["base:enc_ctx", "att_query"]},
      # (B, enc-T, H, 1)
      "att_weights": {"class": "softmax_over_spatial", "from": "energy"},
      # (B, H, V)
      "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},
      # (B, H*V). Note: "static" axes on purpose.
      "att": {"class": "merge_dims", "axes": "static", "from": "att0"},
    },
    shared_base_net={
      "encoder": {"class": "copy", "from": "data"},
      # Key projection: (B, enc-T, D).
      "enc_ctx0": {
        "class": "linear", "activation": None, "with_bias": False,
        "from": "encoder", "n_out": total_key_dim},
      # (B, enc-T, H, D/H)
      "enc_ctx": {
        "class": "split_dims", "axis": "F", "dims": (num_heads, key_dim_per_head),
        "from": "enc_ctx0", "is_output_layer": True},
      # Value projection.
      "enc_value0": {
        "class": "linear", "activation": None, "with_bias": False,
        "from": "encoder", "n_out": total_value_dim},
      # (B, enc-T, H, V/H)
      "enc_value": {
        "class": "split_dims", "axis": "F", "dims": (num_heads, value_dim_per_head),
        "from": "enc_value0", "is_output_layer": True},
    },
    rtol=1e-3)
3485
3522
def test_reclayer_optimize_out_dot_kv_in_rec ():
3486
3523
# Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
3487
3524
AttNumHeads = 4
0 commit comments