
Commit f7f49d1

test_reclayer_optimize_out_cum_concat_gen_self_att
1 parent fe371ea commit f7f49d1

File tree

1 file changed: +33 -2 lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 33 additions & 2 deletions
@@ -3336,8 +3336,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
     rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
@@ -3414,6 +3413,38 @@ def test_reclayer_optimize_out_selfatt_left():
     "class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})


+def test_reclayer_optimize_out_cum_concat_gen_self_att():
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
+  n_key = 5
+  n_value = 7
+  check_reclayer_optimize_out(
+    {"class": "linear", "from": "att", "activation": None},
+    {
+      # This is very much the vanilla self attention,
+      # implemented via the new generic way.
+      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
+      # Commented shapes are always for the layers inside the loop (not optimized).
+      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
+      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
+      "q": {"class": "copy", "from": "qkv_split/0"},  # [B,K]
+      "k": {"class": "copy", "from": "qkv_split/1"},  # [B,K]
+      "v": {"class": "copy", "from": "qkv_split/2"},  # [B,V]
+      # cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
+      # but instead using the optimized format, with extended dyn size on the special dim tag.
+      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # [t,B,K]
+      "v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"},  # [t,B,V]
+      "energy": {
+        "class": "dot", "from": ["q", "k_accum"],
+        "red1": "static:-1", "red2": "static:-1",
+        "var1": None, "var2": new_dim},  # [B,t]
+      "att_weights": {"class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # [B,t]
+      "att": {
+        "class": "dot", "from": ["att_weights", "v_accum"],
+        "red1": new_dim, "red2": new_dim,
+        "var1": None, "var2": "static:-1"},  # [B,V]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
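For context, a minimal single-head NumPy sketch (not part of the commit) of what the "energy", "att_weights", and "att" layers in the added test compute per decoder frame, with keys and values accumulated over frames the way "cum_concat" does inside the rec loop. The helper name generic_self_att_step, the use of NumPy, and the random example data are illustrative assumptions, not RETURNN API.

import numpy as np

def generic_self_att_step(q, k_accum, v_accum):
  """Hypothetical helper: one attention step over accumulated keys/values.

  q: [B,K], k_accum: [t,B,K], v_accum: [t,B,V]; returns att: [B,V].
  """
  # "energy": dot product of the query with every accumulated key -> [B,t]
  energy = np.einsum("bk,tbk->bt", q, k_accum)
  # "att_weights": softmax over the accumulated-time axis t (numerically stabilized)
  energy = energy - energy.max(axis=-1, keepdims=True)
  att_weights = np.exp(energy)
  att_weights /= att_weights.sum(axis=-1, keepdims=True)
  # "att": weighted sum of the accumulated values -> [B,V]
  return np.einsum("bt,tbv->bv", att_weights, v_accum)

# Accumulate k/v frame by frame, analogous to cum_concat inside the rec loop.
B, n_key, n_value = 2, 5, 7
rng = np.random.default_rng(0)
ks, vs = [], []
for _ in range(3):
  q = rng.normal(size=(B, n_key))
  ks.append(rng.normal(size=(B, n_key)))
  vs.append(rng.normal(size=(B, n_value)))
  att = generic_self_att_step(q, np.stack(ks), np.stack(vs))
  assert att.shape == (B, n_value)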
