@@ -3310,34 +3310,39 @@ def test_rec_subnet_simple_rnn():
3310
3310
print ("rnn_cell also fine." )
3311
3311
3312
3312
3313
- def check_reclayer_optimize_out (subnet_layer_dict , other_subnet_layers = None , shared_base_net = None , rtol = 1e-4 ):
3313
+ def check_reclayer_optimize_out (
3314
+ subnet_layer_dict , other_subnet_layers = None , shared_base_net = None ,
3315
+ n_in = 13 , n_out = NotSpecified ,
3316
+ rtol = 1e-4 ):
3314
3317
"""
3315
3318
:param dict[str] subnet_layer_dict: opts for the output layer inside the rec-layer subnet
3316
3319
:param dict[str,dict[str]] other_subnet_layers: other layers for the rec-layer subnet
3317
3320
:param dict[str,dict[str]] shared_base_net:
3321
+ :param int n_in:
3322
+ :param int|NotSpecified|None n_out:
3318
3323
:param float rtol: for the final comparison check
3319
3324
"""
3320
3325
subnet_layer_dict = subnet_layer_dict .copy ()
3321
- n_in = 13
3322
- n_out = subnet_layer_dict .get ("n_out" , 17 )
3326
+ if n_out is NotSpecified :
3327
+ n_out = subnet_layer_dict .get ("n_out" , 17 )
3323
3328
n_batch = 5
3324
3329
n_time = 7
3325
- subnet_layer_dict ["n_out" ] = n_out
3326
3330
subnet_layer_dict .setdefault ("from" , ["data:source" ])
3327
3331
rec_layer_dict = {
3328
3332
"class" : "rec" ,
3329
3333
"from" : ["data" ],
3330
3334
"unit" : {"output" : subnet_layer_dict },
3331
- "n_out" : n_out ,
3332
3335
"is_output_layer" : True
3333
3336
}
3337
+ if n_out is not None :
3338
+ subnet_layer_dict ["n_out" ] = n_out
3339
+ rec_layer_dict ["n_out" ] = n_out
3334
3340
if other_subnet_layers :
3335
3341
assert "output" not in other_subnet_layers
3336
3342
rec_layer_dict ["unit" ].update (other_subnet_layers )
3337
3343
config = Config ({
3338
3344
"debug_print_layer_output_template" : True ,
3339
- "num_inputs" : n_in ,
3340
- "num_outputs" : n_out
3345
+ "extern_data" : {"data" : {"dim" : n_in }},
3341
3346
})
3342
3347
from returnn .tf .layers .rec import _SubnetworkRecCell
3343
3348
with make_scope () as session :
@@ -3414,6 +3419,11 @@ def test_reclayer_optimize_out_selfatt_left():
3414
3419
"class" : "self_attention" , "attention_left_only" : True , "num_heads" : 2 , "total_key_dim" : 6 , "n_out" : 18 })
3415
3420
3416
3421
3422
+ def test_reclayer_optimize_out_cum_concat ():
3423
+ new_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "cum_concat_new_dim" )
3424
+ check_reclayer_optimize_out ({"class" : "cum_concat" , "new_dim" : new_dim }, n_in = 13 , n_out = None )
3425
+
3426
+
3417
3427
def test_reclayer_optimize_out_dot ():
3418
3428
# Used for multi-head dot-attention.
3419
3429
AttNumHeads = 4
0 commit comments