Skip to content

Commit c1e55d2

Browse files
committed
test_reclayer_optimize_out_cum_concat wip
1 parent 8717216 commit c1e55d2

File tree

1 file changed

+17
-7
lines changed

tests/test_TFNetworkRecLayer.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -3310,34 +3310,39 @@ def test_rec_subnet_simple_rnn():
33103310
print("rnn_cell also fine.")
33113311

33123312

3313-
def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, shared_base_net=None, rtol=1e-4):
3313+
def check_reclayer_optimize_out(
3314+
subnet_layer_dict, other_subnet_layers=None, shared_base_net=None,
3315+
n_in=13, n_out=NotSpecified,
3316+
rtol=1e-4):
33143317
"""
33153318
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
33163319
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
33173320
:param dict[str,dict[str]] shared_base_net:
3321+
:param int n_in:
3322+
:param int|NotSpecified|None n_out:
33183323
:param float rtol: for the final comparison check
33193324
"""
33203325
subnet_layer_dict = subnet_layer_dict.copy()
3321-
n_in = 13
3322-
n_out = subnet_layer_dict.get("n_out", 17)
3326+
if n_out is NotSpecified:
3327+
n_out = subnet_layer_dict.get("n_out", 17)
33233328
n_batch = 5
33243329
n_time = 7
3325-
subnet_layer_dict["n_out"] = n_out
33263330
subnet_layer_dict.setdefault("from", ["data:source"])
33273331
rec_layer_dict = {
33283332
"class": "rec",
33293333
"from": ["data"],
33303334
"unit": {"output": subnet_layer_dict},
3331-
"n_out": n_out,
33323335
"is_output_layer": True
33333336
}
3337+
if n_out is not None:
3338+
subnet_layer_dict["n_out"] = n_out
3339+
rec_layer_dict["n_out"] = n_out
33343340
if other_subnet_layers:
33353341
assert "output" not in other_subnet_layers
33363342
rec_layer_dict["unit"].update(other_subnet_layers)
33373343
config = Config({
33383344
"debug_print_layer_output_template": True,
3339-
"num_inputs": n_in,
3340-
"num_outputs": n_out
3345+
"extern_data": {"data": {"dim": n_in}},
33413346
})
33423347
from returnn.tf.layers.rec import _SubnetworkRecCell
33433348
with make_scope() as session:
@@ -3414,6 +3419,11 @@ def test_reclayer_optimize_out_selfatt_left():
34143419
"class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})
34153420

34163421

3422+
def test_reclayer_optimize_out_cum_concat():
3423+
new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
3424+
check_reclayer_optimize_out({"class": "cum_concat", "new_dim": new_dim}, n_in=13, n_out=None)
3425+
3426+
34173427
def test_reclayer_optimize_out_dot():
34183428
# Used for multi-head dot-attention.
34193429
AttNumHeads = 4

0 commit comments

Comments (0)