Commit a406d3f

test_reclayer_optimize_out_cum_concat_gen_self_att
#589, #391
1 parent 4661bbe commit a406d3f

1 file changed: +35 −2 lines changed

tests/test_TFNetworkRecLayer.py (+35 −2)
@@ -3385,8 +3385,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
     rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
@@ -3463,6 +3462,40 @@ def test_reclayer_optimize_out_selfatt_left():
     "class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})


+def test_reclayer_optimize_out_cum_concat_gen_self_att():
+  new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
+  n_key = 5
+  n_value = 7
+  check_reclayer_optimize_out(
+    {"class": "linear", "from": "att", "activation": None},
+    {
+      # This is very much the vanilla self attention,
+      # implemented via the new generic way.
+      # See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
+      # Commented shapes are always for the layers inside the loop (not optimized).
+      "qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value},  # [B,2*K+V]
+      "qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
+      "q": {"class": "copy", "from": "qkv_split/0"},  # inside [B,K]. optimized out [T,B,K]
+      "k": {"class": "copy", "from": "qkv_split/1"},  # inside [B,K]. optimized out [T,B,K]
+      "v": {"class": "copy", "from": "qkv_split/2"},  # inside [B,V]. optimized out [T,B,V]
+      # cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
+      # but instead using the optimized format, with extended dyn size on the special dim tag,
+      # i.e. [t*,B,K], representing [T,t*,B,K].
+      "k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"},  # inside [t,B,K]. opt out [t*,B,K]
+      "v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"},  # inside [t,B,V]. opt out [t*,B,V]
+      "energy": {
+        "class": "dot", "from": ["q", "k_accum"],
+        "red1": "static:-1", "red2": "static:-1",
+        "var1": None, "var2": new_dim},  # inside [B,t]. optimized out [T,B,t*]
+      "att_weights": {
+        "class": "softmax_over_spatial", "from": "energy", "axis": new_dim},  # inside [B,t]. opt out [T,B,t*]
+      "att": {
+        "class": "dot", "from": ["att_weights", "v_accum"],
+        "red1": new_dim, "red2": new_dim,
+        "var1": None, "var2": "static:-1"},  # inside [B,V]. opt out [T,B,V]
+    })
+
+
 def test_reclayer_optimize_out_dot():
   # Used for multi-head dot-attention.
   AttNumHeads = 4
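
For orientation, here is a minimal NumPy sketch (editorial, not part of the commit) of what the cum_concat-based layer graph in the new test computes, for a single batch entry and a single head. The function names, the random inputs, and the absence of 1/sqrt(K) energy scaling are assumptions taken from the config above rather than RETURNN code; the second function shows the equivalent full-sequence (optimized-out) view, where the per-step accumulation becomes a causal mask over the accumulated axis.

import numpy as np

def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

def gen_self_att_stepwise(x, W, n_key=5, n_value=7):
  """x: [T, n_in], W: [n_in, 2*n_key + n_value]. Mirrors the per-step view inside the loop."""
  qkv = x @ W                                            # "qkv": [T, 2*K+V]
  q, k, v = np.split(qkv, [n_key, 2 * n_key], axis=-1)   # "qkv_split": [T,K], [T,K], [T,V]
  out = []
  for t in range(x.shape[0]):
    k_accum, v_accum = k[:t + 1], v[:t + 1]              # "k_accum"/"v_accum": accumulated up to step t
    energy = k_accum @ q[t]                              # "energy": [t+1], plain dot product (no scaling)
    att_weights = softmax(energy)                        # "att_weights": softmax over the accumulated axis
    out.append(att_weights @ v_accum)                    # "att": [V]
  return np.stack(out)                                   # [T, V]

def gen_self_att_full(x, W, n_key=5, n_value=7):
  """Same computation on whole sequences with a causal mask, i.e. the optimized-out view."""
  qkv = x @ W
  q, k, v = np.split(qkv, [n_key, 2 * n_key], axis=-1)
  energy = q @ k.T                                       # [T, t*]
  mask = np.tril(np.ones(energy.shape, dtype=bool))      # position t may attend to positions <= t
  energy = np.where(mask, energy, -np.inf)
  return softmax(energy, axis=-1) @ v                    # [T, V]

rng = np.random.default_rng(0)
x, W = rng.normal(size=(4, 3)), rng.normal(size=(3, 2 * 5 + 7))
assert np.allclose(gen_self_att_stepwise(x, W), gen_self_att_full(x, W))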
