@@ -1233,7 +1233,7 @@ def test_attention_no_encoder_dependency():
         'n_out': 4, 'padding': 'same'},
       "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
       "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': ['location_feedback', 's_transformed']},
-      "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+      "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
     },
   },
   "decision": {"class": "decide", "from": ["output"], "loss": "edit_distance"}
@@ -1345,7 +1345,7 @@ def test_attention_convolutional_feedback_variant1():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -1373,7 +1373,7 @@ def test_attention_convolutional_feedback_variant2():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -1412,7 +1412,7 @@ def test_attention_convolutional_feedback_variant3():
     "location_feedback": {'class': 'linear', 'from': ['convolved_att'], 'n_out': 6, 'activation': None},
     "att_energy_in": {'class': 'combine', 'kind': 'add', 'from': [
       'base:enc_transformed', 'location_feedback', 's_transformed']},
-    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights"},
+    "c": {"class": "generic_attention", "base": "base:encoder", "weights": "att_weights", "auto_squeeze": True},
   }
 
   check_attention_variant(recurrent_unit_dict)
@@ -2135,7 +2135,7 @@ def test_rec_subnet_construct_1():
       "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
         "eval": "source(0) + source(1) * source(2) * 0.5",
         "out_type": {"dim": 1, "shape": (None, 1)}},
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
       "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["s"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["prev:s2", "prev:target_embed", "att"], "activation": None, "n_out": 10},
@@ -2192,7 +2192,7 @@ def test_rec_subnet_construct_2():
       "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
         "eval": "source(0) + source(1) * source(2) * 0.5",
         "out_type": {"dim": 1, "shape": (None, 1)}},
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
       "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["s"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["prev:s2", "prev:target_embed", "att"], "activation": None, "n_out": 10},
@@ -2255,7 +2255,7 @@ def test_rec_subnet_construct_3():
      "accum_att_weights": {"class": "eval", "from": ["prev:accum_att_weights", "att_weights", "base:inv_fertility"],
        "eval": "source(0) + source(1) * source(2) * 0.5",
        "out_type": {"dim": 1, "shape": (None, 1)}},
-     "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+     "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
      "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 10},
      "s2": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:s", "prev:target_embed", "att"], "n_out": 10},
      "readout_in": {"class": "linear", "from": ["s2"], "activation": None, "n_out": 10},
@@ -2288,9 +2288,9 @@ def test_rec_subnet_eval_init_out_apply0():
   # (also defined by num_inputs & num_outputs)
   beam_size = 3
   AttNumHeads = 2
-  EncKeyTotalDim = AttNumHeads * 2
+  EncKeyTotalDim = AttNumHeads * 5
   EncKeyPerHeadDim = EncKeyTotalDim // AttNumHeads
-  EncValueTotalDim = AttNumHeads * 2
+  EncValueTotalDim = AttNumHeads * 5
   EncValuePerHeadDim = EncValueTotalDim // AttNumHeads
   network = {
     "lstm0_fw": {"class": "rec", "unit": "nativelstm2", "n_out": 2, "direction": 1, "from": "data:data"},
@@ -2334,7 +2334,7 @@ def test_rec_subnet_eval_init_out_apply0():
         "eval": "source(0) + source(1) * source(2) * 0.5",
         "out_type": {"dim": 1, "shape": (None, 1)}, "initial_output": "apply(0)"},  # (B, enc-T, 1)
       "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"},  # (B, H, V)
-      "att": {"class": "merge_dims", "axes": "except_batch", "from": ["att0"]},  # (B, H*V)
+      "att": {"class": "merge_dims", "axes": ["dim:%i" % AttNumHeads, "dim:%i" % EncValuePerHeadDim], "from": "att0"},  # (B, H*V)
 
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["target_embed", "att"], "n_out": 2},  # transform
       "readout_in": {"class": "linear", "from": ["prev:s", "prev:target_embed", "att"], "activation": None,
@@ -2768,13 +2768,13 @@ def custom_construction_algo(idx, net_dict):
 def test_net_safe_log_to_log_softmax():
   n_out = 5
   net_dict = {
-    "ff_in_window": {"class": "window", "window_size": 3, "from": "data:data"},  # (B,T,3,3)
-    "ff_in": {"class": "merge_dims", "axes": "except_time", "from": ["ff_in_window"]},  # (B,T,9)
-    "ff0": {"class": "hidden", "activation": "relu", "n_out": 8, "L2": 0.01, "from": ["ff_in"]},  # (B,T,8)
-    "ff_out": {"class": "softmax", "n_out": n_out, "from": ["ff0"]},  # (B,T,5)
+    "ff_in_window": {"class": "window", "window_size": 4, "from": "data:data"},  # (B,T,4,3)
+    "ff_in": {"class": "merge_dims", "axes": ["dim:3", "dim:4"], "from": "ff_in_window"},  # (B,T,12)
+    "ff0": {"class": "hidden", "activation": "relu", "n_out": 8, "L2": 0.01, "from": "ff_in"},  # (B,T,8)
+    "ff_out": {"class": "softmax", "n_out": n_out, "from": "ff0"},  # (B,T,5)
     "ff_out_prior": {
       "class": "accumulate_mean", "exp_average": 0.001,
-      "is_prob_distribution": True, "from": ["ff_out"]},  # (5,)
+      "is_prob_distribution": True, "from": "ff_out"},  # (5,)
     "output": {
       "class": "combine", "kind": "eval", "from": ["ff_out", "ff_out_prior"],
       "eval": "safe_log(source(0)) - safe_log(source(1))",
@@ -2826,7 +2826,7 @@ def test_preload_from_files():
         "class": "linear", "activation": None, "n_out": n_hidden, "from": "data:data",
         'bias_init': 1.0, 'forward_weights_init': 'orthogonal'},
       "output": {
-        "class": "linear", "activation": None, "n_out": n_out, "from": ["l1"],
+        "class": "linear", "activation": None, "n_out": n_out, "from": "l1",
         'bias_init': 2.0, 'forward_weights_init': 'orthogonal'}
     }
   })
@@ -3366,7 +3366,7 @@ def test_attention_forward_hdf_then_unflatten_2d():
       # (B, enc-T, 1)
       "energy": {"class": "linear", "activation": None, "with_bias": False, "from": ["energy_tanh"], "n_out": 1},
       "att_weights": {"class": "softmax_over_spatial", "from": ["energy"], "is_output_layer": True},  # (B, enc-T, 1)
-      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder"},
+      "att": {"class": "generic_attention", "weights": "att_weights", "base": "base:encoder", "auto_squeeze": True},
       "s": {"class": "rnn_cell", "unit": "LSTMBlock", "from": ["prev:target_embed", "prev:att"], "n_out": 10},
       "readout_in": {"class": "linear", "from": ["s", "prev:target_embed", "att"], "activation": None, "n_out": 10},
       "readout": {"class": "reduce_out", "mode": "max", "num_pieces": 2, "from": ["readout_in"]},