Skip to content

Commit e99be05

Browse files
committed
test_reclayer_optimize_out_dot_consistent_axes
Test for #569
1 parent 842e0ca commit e99be05

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3482,6 +3482,43 @@ def test_reclayer_optimize_out_dot():
34823482
rtol=1e-3)
34833483

34843484

3485+
def test_reclayer_optimize_out_dot_consistent_axes():
  """
  Regression test for https://github.com/rwth-i6/returnn/issues/569.

  Builds a multi-head dot-attention subnetwork inside a rec layer and checks
  (via :func:`check_reclayer_optimize_out`) that the optimized-out (moved out
  of the loop) computation matches the in-loop computation, with consistent
  axis handling in the dot-layer.
  """
  n_heads, n_key, n_value = 4, 5, 7
  n_key_total = n_heads * n_key
  n_value_total = n_heads * n_value

  def _linear_no_bias(src, dim):
    # Config for a bias-free linear projection layer.
    return {"class": "linear", "activation": None, "with_bias": False, "from": src, "n_out": dim}

  subnet_layers = {
    # (B, D) -- Q (query). D should be same as enc_ctx
    "s": _linear_no_bias("data:source", n_key_total),
    "att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"},  # (B, H, D/H)
    # Here is the main test, the dot-layer:
    "energy": {"class": "dot", "red1": -1, "red2": -1, "var1": "T", "var2": "T",
               "from": ["base:enc_ctx", "att_query"]},
    # energy inside the loop will be (B, H, enc-T, 1).
    # energy outside the loop will be (B, H, enc-T, dec-T). I.e. enc-T is still the first time axis.
    "att_weights": {"class": "softmax_over_spatial", "from": "energy"},  # (B, enc-T, H, 1)
    "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},  # (B, H, V)
    "att": {"class": "merge_dims", "axes": "static", "from": "att0"},  # (B, H*V); Use "static" here.
  }

  base_net = {
    "encoder": {"class": "copy", "from": "data"},
    "enc_ctx0": _linear_no_bias("encoder", n_key_total),  # (B, enc-T, D)
    "enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key),
                "from": "enc_ctx0", "is_output_layer": True},  # (B, enc-T, H, D/H)
    "enc_value0": _linear_no_bias("encoder", n_value_total),
    "enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value),
                  "from": "enc_value0", "is_output_layer": True},  # (B, enc-T, H, D/H)
  }

  check_reclayer_optimize_out(
    {"class": "linear", "activation": None, "from": "att"},
    other_subnet_layers=subnet_layers,
    shared_base_net=base_net,
    rtol=1e-3)

34853522
def test_reclayer_optimize_out_dot_kv_in_rec():
34863523
# Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
34873524
AttNumHeads = 4

0 commit comments

Comments
 (0)