Skip to content

Commit bd2771d

Browse files
committed
test_reclayer_optimize_out_dot_consistent_axes
Test for #569
1 parent 8591808 commit bd2771d

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3460,6 +3460,43 @@ def test_reclayer_optimize_out_dot():
34603460
rtol=1e-3)
34613461

34623462

3463+
def test_reclayer_optimize_out_dot_consistent_axes():
  """
  Test that a "dot" layer inside a rec layer keeps consistent axes when the rec
  layer is optimized out of the loop (regression test for
  https://github.com/rwth-i6/returnn/issues/569).

  Builds a multi-head dot-attention setup: query/key/value projections are split
  into heads via "split_dims", the "dot" layer computes the energies, and
  "generic_attention" applies the weights. ``check_reclayer_optimize_out``
  (defined elsewhere in this file) runs the net both inside and outside the
  loop and compares outputs up to ``rtol``.
  """
  # https://github.com/rwth-i6/returnn/issues/569
  # Used for multi-head dot-attention.
  n_heads = 4
  n_key = 5
  n_value = 7
  # Total dims before the per-head split (H * D/H).
  n_key_total = n_heads * n_key
  n_value_total = n_heads * n_value
  check_reclayer_optimize_out(
    {"class": "linear", "activation": None, "from": "att"},
    other_subnet_layers={
      "s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source",
            "n_out": n_key_total},  # (B, D) -- Q (query). D should be same as enc_ctx
      "att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"},  # (B, H, D/H)
      # Here is the main test, the dot-layer:
      "energy": {"class": "dot", "red1": -1, "red2": -1, "var1": "T", "var2": "T",
                 "from": ["base:enc_ctx", "att_query"]},
      # energy inside the loop will be (B, H, enc-T, 1).
      # energy outside the loop will be (B, H, enc-T, dec-T). I.e. enc-T is still the first time axis.
      "att_weights": {"class": "softmax_over_spatial", "from": "energy"},  # (B, enc-T, H, 1)
      "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},  # (B, H, V)
      "att": {"class": "merge_dims", "axes": "static", "from": "att0"},  # (B, H*V); Use "static" here.
      },
    shared_base_net={
      # Shared encoder-side net; keys/values are computed once outside the rec loop.
      "encoder": {"class": "copy", "from": "data"},
      "enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
                   "n_out": n_key_total},  # (B, enc-T, D)
      "enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key),
                  "from": "enc_ctx0", "is_output_layer": True},  # (B, enc-T, H, D/H)
      "enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
                     "n_out": n_value_total},
      "enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value),
                    "from": "enc_value0", "is_output_layer": True},  # (B, enc-T, H, D/H)
      },
    rtol=1e-3)
3499+
34633500
def test_reclayer_optimize_out_dot_kv_in_rec():
34643501
# Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
34653502
AttNumHeads = 4

0 commit comments

Comments
 (0)