
Commit 2384c1b

test_reclayer_optimize_out_cum_concat_gen_self_att
1 parent 09c4cb8 commit 2384c1b

File tree

1 file changed: +33 −2

tests/test_TFNetworkRecLayer.py

+33 −2
@@ -3345,8 +3345,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
   rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
@@ -3423,6 +3422,38 @@ def test_reclayer_optimize_out_selfatt_left():
     "class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})
 
 
+def test_reclayer_optimize_out_cum_concat_gen_self_att():
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
+  n_key = 5
+  n_value = 7
+  check_reclayer_optimize_out(
+    {"class": "linear", "from": "att", "activation": None},
+    {
+      # This is very much the vanilla self attention,
+      # implemented via the new generic way.
+      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
+      # Commented shapes are always for the layers inside the loop (not optimized).
+      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
+      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
+      "q": {"class": "copy", "from": "qkv_split/0"},  # [B,K]
+      "k": {"class": "copy", "from": "qkv_split/1"},  # [B,K]
+      "v": {"class": "copy", "from": "qkv_split/2"},  # [B,V]
+      # cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
+      # but instead using the optimized format, with extended dyn size on the special dim tag.
+      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # [t,B,K]
+      "v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"},  # [t,B,V]
+      "energy": {
+        "class": "dot", "from": ["q", "k_accum"],
+        "red1": "static:-1", "red2": "static:-1",
+        "var1": None, "var2": new_dim},  # [B,t]
+      "att_weights": {"class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # [B,t]
+      "att": {
+        "class": "dot", "from": ["att_weights", "v_accum"],
+        "red1": new_dim, "red2": new_dim,
+        "var1": None, "var2": "static:-1"},  # [B,V]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
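
For orientation, the following is a minimal NumPy sketch of the per-step computation that the layer dict above describes inside the loop: project the current frame to q/k/v, append k and v to the accumulated (cum_concat) state, and attend over everything seen so far. It is an illustration only, assuming a single attention head and no batch dimension; the function name step_attention and the weight matrix W (standing in for the "qkv" linear layer) are hypothetical and not part of RETURNN or this test.

import numpy as np

def step_attention(x_t, W, k_accum, v_accum, n_key=5, n_value=7):
  """One step of cum_concat-style self attention, single head, no batch dim.

  x_t: current input frame, shape [n_in]
  W: projection matrix, shape [n_in, 2 * n_key + n_value]  (hypothetical "qkv" weights)
  k_accum, v_accum: lists of previously accumulated keys/values (the cum_concat state)
  """
  qkv = x_t @ W                                # like the "qkv" linear layer, shape [2*K+V]
  q, k, v = np.split(qkv, [n_key, 2 * n_key])  # like the "qkv_split" layer -> [K], [K], [V]
  k_accum.append(k)                            # "k_accum": cum_concat over time
  v_accum.append(v)                            # "v_accum": cum_concat over time
  K = np.stack(k_accum)                        # [t, K]
  V = np.stack(v_accum)                        # [t, V]
  energy = K @ q                               # "energy": dot over the key dim -> [t]
  weights = np.exp(energy - energy.max())
  weights /= weights.sum()                     # "att_weights": softmax over the accumulated axis
  att = weights @ V                            # "att": weighted sum of values -> [V]
  return att, k_accum, v_accum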
