forked from lvapeab/nmt-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path model_zoo.py
670 lines (598 loc) · 41.2 KB
/
model_zoo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
import logging
import os
from keras.layers import *
from keras.models import model_from_json, Model
from keras.optimizers import Adam, RMSprop, Nadam, Adadelta, SGD, Adagrad, Adamax
from keras.regularizers import l2, AlphaRegularizer
from keras_wrapper.cnn_model import Model_Wrapper
from keras_wrapper.extra.regularize import Regularize
class TranslationModel(Model_Wrapper):
    """
    Translation model class. Instance of the Model_Wrapper class (see staged_keras_wrapper).

    :param params: all hyperparams of the model.
    :param model_type: network name type (corresponds to any method defined in the section 'MODELS' of this class).
                 Only valid if 'structure_path' == None.
    :param verbose: set to 0 if you don't want the model to output informative messages
    :param structure_path: path to a Keras' model json file.
                          If we specify this parameter then 'type' will be only an informative parameter.
    :param weights_path: path to the pre-trained weights file (if None, then it will be randomly initialized)
    :param model_name: optional name given to the network
                       (if None, then it will be assigned to current time as its name)
    :param vocabularies: vocabularies used for word embedding
    :param store_path: path to the folder where the temporal model backups will be stored
    :param set_optimizer: Compile optimizer or not.
    :param clear_dirs: Clean model directories or not.
    """

    def __init__(self, params, model_type='Translation_Model', verbose=1, structure_path=None, weights_path=None,
                 model_name=None, vocabularies=None, store_path=None, set_optimizer=True, clear_dirs=True):
        """
        Translation_Model object constructor.

        :param params: all hyperparams of the model.
        :param model_type: network name type (corresponds to any method defined in the section 'MODELS' of this class).
                     Only valid if 'structure_path' == None.
        :param verbose: set to 0 if you don't want the model to output informative messages
        :param structure_path: path to a Keras' model json file.
                              If we specify this parameter then 'type' will be only an informative parameter.
        :param weights_path: path to the pre-trained weights file (if None, then it will be randomly initialized)
        :param model_name: optional name given to the network
                           (if None, then it will be assigned to current time as its name)
        :param vocabularies: vocabularies used for word embedding
        :param store_path: path to the folder where the temporal model backups will be stored
        :param set_optimizer: Compile optimizer or not.
        :param clear_dirs: Clean model directories or not.
        :raises Exception: if model_type does not name a model-building method of this class
                           (only checked when structure_path is not given).
        """
        super(TranslationModel, self).__init__(type=model_type, model_name=model_name,
                                               silence=verbose == 0, models_path=store_path, inheritance=True)

        self.__toprint = ['_model_type', 'name', 'model_path', 'verbose']

        self.verbose = verbose
        self._model_type = model_type
        self.params = params
        self.vocabularies = vocabularies
        self.ids_inputs = params['INPUTS_IDS_MODEL']
        self.ids_outputs = params['OUTPUTS_IDS_MODEL']
        # Attention weights are needed as an extra decoder output for coverage penalty / POS_UNK replacement.
        self.return_alphas = params['COVERAGE_PENALTY'] or params['POS_UNK']

        # Sets the model name and prepares the folders for storing the models
        self.setName(model_name, models_path=store_path, clear_dirs=clear_dirs)

        # Prepare source word embedding
        if params['SRC_PRETRAINED_VECTORS'] is not None:
            self.src_embedding_weights = self._load_pretrained_embedding(params['SRC_PRETRAINED_VECTORS'],
                                                                         params['INPUT_VOCABULARY_SIZE'],
                                                                         params['SOURCE_TEXT_EMBEDDING_SIZE'],
                                                                         self.ids_inputs[0])
            self.src_embedding_weights_trainable = params['SRC_PRETRAINED_VECTORS_TRAINABLE']
        else:
            self.src_embedding_weights = None
            self.src_embedding_weights_trainable = True

        # Prepare target word embedding
        if params['TRG_PRETRAINED_VECTORS'] is not None:
            self.trg_embedding_weights = self._load_pretrained_embedding(params['TRG_PRETRAINED_VECTORS'],
                                                                         params['OUTPUT_VOCABULARY_SIZE'],
                                                                         params['TARGET_TEXT_EMBEDDING_SIZE'],
                                                                         self.ids_outputs[0])
            self.trg_embedding_weights_trainable = params['TRG_PRETRAINED_VECTORS_TRAINABLE']
        else:
            self.trg_embedding_weights = None
            self.trg_embedding_weights_trainable = True

        # Prepare model
        if structure_path:
            # Load a .json model
            if self.verbose > 0:
                logging.info("<<< Loading model structure from file " + structure_path + " >>>")
            with open(structure_path) as structure_file:
                self.model = model_from_json(structure_file.read())
        else:
            # Build model from scratch by calling the method named model_type (e.g. 'GroundHogModel').
            if hasattr(self, model_type):
                if self.verbose > 0:
                    logging.info("<<< Building " + model_type + " Translation_Model >>>")
                getattr(self, model_type)(params)
            else:
                raise Exception('Translation_Model model_type "' + model_type + '" is not implemented.')

        # Load weights from file
        if weights_path:
            if self.verbose > 0:
                logging.info("<<< Loading weights from file " + weights_path + " >>>")
            self.model.load_weights(weights_path)

        # Print information of self
        if verbose > 0:
            print(str(self))
            self.model.summary()

        if set_optimizer:
            self.setOptimizer()

    def _load_pretrained_embedding(self, vectors_path, vocabulary_size, embedding_size, vocabulary_id):
        """
        Builds an embedding matrix, overwriting random rows with pretrained vectors where available.

        Words of the vocabulary that have no pretrained vector keep a uniform random [0, 1) initialization.

        :param vectors_path: path to a numpy file whose loaded value exposes ``.item()`` returning a
                             {word: vector} mapping.
        :param vocabulary_size: number of rows of the embedding matrix.
        :param embedding_size: number of columns of the embedding matrix.
        :param vocabulary_id: key into self.vocabularies whose 'words2idx' maps words to row indices.
        :return: a one-element list holding the matrix, the form passed as ``weights=`` to Embedding.
        """
        if self.verbose > 0:
            logging.info("<<< Loading pretrained word vectors from: " + vectors_path + " >>>")
        # NOTE(review): newer numpy versions require allow_pickle=True to load object arrays like this one.
        word_vectors = np.load(vectors_path).item()
        embedding_weights = np.random.rand(vocabulary_size, embedding_size)
        for word, index in self.vocabularies[vocabulary_id]['words2idx'].items():
            vector = word_vectors.get(word)
            if vector is not None:
                embedding_weights[index, :] = vector
        return [embedding_weights]

    def setParams(self, params):
        """
        Replaces the hyperparameter dictionary of the model.

        :param params: new dictionary of hyperparameters.
        """
        self.params = params

    def setOptimizer(self, **kwargs):
        """
        Sets and compiles a new optimizer for the Translation_Model.

        The optimizer is chosen case-insensitively from params['OPTIMIZER']. Its hyperparameters
        (LR, MOMENTUM, RHO, BETA_1, BETA_2, LR_OPTIMIZER_DECAY, CLIP_C, CLIP_V, ...) are read from
        self.params, with per-optimizer defaults.

        :param kwargs: unused.
        :return: None
        """
        if self.verbose > 0:
            logging.info("Preparing optimizer: %s [LR: %s - LOSS: %s] and compiling." %
                         (str(self.params['OPTIMIZER']), str(self.params.get('LR', 0.01)),
                          str(self.params.get('LOSS', 'categorical_crossentropy'))))

        optimizer_name = self.params['OPTIMIZER'].lower()
        if optimizer_name == 'sgd':
            optimizer = SGD(lr=self.params.get('LR', 0.01),
                            momentum=self.params.get('MOMENTUM', 0.0),
                            decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                            nesterov=self.params.get('NESTEROV_MOMENTUM', False),
                            clipnorm=self.params.get('CLIP_C', 0.),
                            clipvalue=self.params.get('CLIP_V', 0.), )
        elif optimizer_name in ('rmsprop', 'rsmprop'):
            # 'rsmprop' is the historical misspelling previously checked here; kept so old configs still work.
            optimizer = RMSprop(lr=self.params.get('LR', 0.001),
                                rho=self.params.get('RHO', 0.9),
                                decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                                clipnorm=self.params.get('CLIP_C', 0.),
                                clipvalue=self.params.get('CLIP_V', 0.))
        elif optimizer_name == 'adagrad':
            optimizer = Adagrad(lr=self.params.get('LR', 0.01),
                                decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                                clipnorm=self.params.get('CLIP_C', 0.),
                                clipvalue=self.params.get('CLIP_V', 0.))
        elif optimizer_name == 'adadelta':
            optimizer = Adadelta(lr=self.params.get('LR', 1.0),
                                 rho=self.params.get('RHO', 0.9),
                                 decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                                 clipnorm=self.params.get('CLIP_C', 0.),
                                 clipvalue=self.params.get('CLIP_V', 0.))
        elif optimizer_name == 'adam':
            optimizer = Adam(lr=self.params.get('LR', 0.001),
                             beta_1=self.params.get('BETA_1', 0.9),
                             beta_2=self.params.get('BETA_2', 0.999),
                             decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                             clipnorm=self.params.get('CLIP_C', 0.),
                             clipvalue=self.params.get('CLIP_V', 0.))
        elif optimizer_name == 'adamax':
            optimizer = Adamax(lr=self.params.get('LR', 0.002),
                               beta_1=self.params.get('BETA_1', 0.9),
                               beta_2=self.params.get('BETA_2', 0.999),
                               decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                               clipnorm=self.params.get('CLIP_C', 0.),
                               clipvalue=self.params.get('CLIP_V', 0.))
        elif optimizer_name == 'nadam':
            optimizer = Nadam(lr=self.params.get('LR', 0.002),
                              beta_1=self.params.get('BETA_1', 0.9),
                              beta_2=self.params.get('BETA_2', 0.999),
                              schedule_decay=self.params.get('LR_OPTIMIZER_DECAY', 0.0),
                              clipnorm=self.params.get('CLIP_C', 0.),
                              clipvalue=self.params.get('CLIP_V', 0.))
        else:
            logging.warning('\tWARNING: The modification of the LR is not implemented for the chosen optimizer.')
            # NOTE(review): eval() of a config value; acceptable only because params come from a trusted config.
            optimizer = eval(self.params['OPTIMIZER'])

        self.model.compile(optimizer=optimizer, loss=self.params['LOSS'],
                           metrics=self.params.get('KERAS_METRICS', []),
                           sample_weight_mode='temporal' if self.params['SAMPLE_WEIGHTS'] else None)

    def __str__(self):
        """
        Plots basic model information.

        :return: String containing model information.
        """
        obj_str = '-----------------------------------------------------------------------------------\n'
        class_name = self.__class__.__name__
        obj_str += '\t\t' + class_name + ' instance\n'
        obj_str += '-----------------------------------------------------------------------------------\n'

        # Print pickled attributes
        for att in self.__toprint:
            obj_str += att + ': ' + str(self.__dict__[att])
            obj_str += '\n'

        obj_str += '\n'
        obj_str += 'MODEL params:\n'
        obj_str += str(self.params)
        obj_str += '\n'
        obj_str += '-----------------------------------------------------------------------------------'

        return obj_str

    # ------------------------------------------------------- #
    #       PREDEFINED MODELS
    # ------------------------------------------------------- #

    def GroundHogModel(self, params):
        """
        Neural machine translation with:
            * BRNN encoder
            * Attention mechanism on input sequence of annotations
            * Conditional RNN for decoding
            * Deep output layers:
                * Context projected to output
                * Last word projected to output
            * Possibly deep encoder/decoder
        See https://arxiv.org/abs/1409.0473 for an in-depth review of the model.

        Sets self.model (training model), self.model_init and self.model_next (sampling models sharing
        the training layers), plus the input/output id lists and the state matchings used to chain them.

        :param params: Dictionary of params (see config.py)
        :return: None
        """
        use_lstm_decoder = 'LSTM' in params['DECODER_RNN_TYPE']

        def encoder_rnn_kwargs():
            # Built anew for every encoder layer so that no regularizer object is shared between layers.
            return dict(kernel_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                        recurrent_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                        bias_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                        dropout=params['RECURRENT_INPUT_DROPOUT_P'],
                        recurrent_dropout=params['RECURRENT_DROPOUT_P'],
                        kernel_initializer=params['INIT_FUNCTION'],
                        recurrent_initializer=params['INNER_INIT'],
                        return_sequences=True)

        # 1. Source text input
        src_text = Input(name=self.ids_inputs[0], batch_shape=tuple([None, None]), dtype='int32')

        # 2. Encoder
        # 2.1. Source word embedding
        src_embedding = Embedding(params['INPUT_VOCABULARY_SIZE'], params['SOURCE_TEXT_EMBEDDING_SIZE'],
                                  name='source_word_embedding',
                                  embeddings_regularizer=l2(params['WEIGHT_DECAY']),
                                  embeddings_initializer=params['INIT_FUNCTION'],
                                  trainable=self.src_embedding_weights_trainable,
                                  weights=self.src_embedding_weights,
                                  mask_zero=True)(src_text)
        src_embedding = Regularize(src_embedding, params, name='src_embedding')

        # 2.2. BRNN encoder (GRU/LSTM)
        encoder_rnn_class = eval(params['ENCODER_RNN_TYPE'])
        if params['BIDIRECTIONAL_ENCODER']:
            annotations = Bidirectional(encoder_rnn_class(params['ENCODER_HIDDEN_SIZE'], **encoder_rnn_kwargs()),
                                        name='bidirectional_encoder_' + params['ENCODER_RNN_TYPE'],
                                        merge_mode='concat')(src_embedding)
        else:
            annotations = encoder_rnn_class(params['ENCODER_HIDDEN_SIZE'],
                                            name='encoder_' + params['ENCODER_RNN_TYPE'],
                                            **encoder_rnn_kwargs())(src_embedding)
        annotations = Regularize(annotations, params, name='annotations')

        # 2.3. Potentially deep encoder: each extra layer's output is added to the previous annotations
        for n_layer in range(1, params['N_LAYERS_ENCODER']):
            if params['BIDIRECTIONAL_DEEP_ENCODER']:
                current_annotations = Bidirectional(encoder_rnn_class(params['ENCODER_HIDDEN_SIZE'],
                                                                      **encoder_rnn_kwargs()),
                                                    merge_mode='concat',
                                                    name='bidirectional_encoder_' + str(n_layer))(annotations)
            else:
                current_annotations = encoder_rnn_class(params['ENCODER_HIDDEN_SIZE'],
                                                        name='encoder_' + str(n_layer),
                                                        **encoder_rnn_kwargs())(annotations)
            current_annotations = Regularize(current_annotations, params, name='annotations_' + str(n_layer))
            annotations = Add()([annotations, current_annotations])

        # 3. Decoder
        # 3.1.1. Previously generated words as inputs for training -> Teacher forcing
        next_words = Input(name=self.ids_inputs[1], batch_shape=tuple([None, None]), dtype='int32')
        # 3.1.2. Target word embedding
        state_below = Embedding(params['OUTPUT_VOCABULARY_SIZE'], params['TARGET_TEXT_EMBEDDING_SIZE'],
                                name='target_word_embedding',
                                embeddings_regularizer=l2(params['WEIGHT_DECAY']),
                                embeddings_initializer=params['INIT_FUNCTION'],
                                trainable=self.trg_embedding_weights_trainable,
                                weights=self.trg_embedding_weights,
                                mask_zero=True)(next_words)
        state_below = Regularize(state_below, params, name='state_below')

        # 3.2. Decoder's RNN initialization perceptrons with ctx mean
        ctx_mean = MaskedMean()(annotations)
        annotations = MaskLayer()(annotations)  # We may want the padded annotations

        if len(params['INIT_LAYERS']) > 0:
            for n_layer_init in range(len(params['INIT_LAYERS']) - 1):
                ctx_mean = Dense(params['DECODER_HIDDEN_SIZE'], name='init_layer_%d' % n_layer_init,
                                 kernel_initializer=params['INIT_FUNCTION'],
                                 kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                 bias_regularizer=l2(params['WEIGHT_DECAY']),
                                 activation=params['INIT_LAYERS'][n_layer_init]
                                 )(ctx_mean)
                ctx_mean = Regularize(ctx_mean, params, name='ctx' + str(n_layer_init))

            initial_state = Dense(params['DECODER_HIDDEN_SIZE'], name='initial_state',
                                  kernel_initializer=params['INIT_FUNCTION'],
                                  kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                  bias_regularizer=l2(params['WEIGHT_DECAY']),
                                  activation=params['INIT_LAYERS'][-1]
                                  )(ctx_mean)
            initial_state = Regularize(initial_state, params, name='initial_state')
            input_attentional_decoder = [state_below, annotations, initial_state]

            if use_lstm_decoder:
                initial_memory = Dense(params['DECODER_HIDDEN_SIZE'], name='initial_memory',
                                       kernel_initializer=params['INIT_FUNCTION'],
                                       kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                       bias_regularizer=l2(params['WEIGHT_DECAY']),
                                       activation=params['INIT_LAYERS'][-1])(ctx_mean)
                initial_memory = Regularize(initial_memory, params, name='initial_memory')
                input_attentional_decoder.append(initial_memory)
        else:
            # Initialize to zeros vector
            input_attentional_decoder = [state_below, annotations]
            initial_state = ZeroesLayer(params['DECODER_HIDDEN_SIZE'])(ctx_mean)
            input_attentional_decoder.append(initial_state)
            if use_lstm_decoder:
                # The zero state is also the initial memory. Binding initial_memory here keeps the deep
                # decoder layers below (which read it) from raising NameError.
                initial_memory = initial_state
                input_attentional_decoder.append(initial_memory)

        # 3.3. Attentional decoder
        sharedAttRNNCond = eval('Att' + params['DECODER_RNN_TYPE'] + 'Cond')(
            params['DECODER_HIDDEN_SIZE'],
            att_units=params.get('ATTENTION_SIZE', 0),
            kernel_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
            recurrent_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
            conditional_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
            bias_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
            attention_context_wa_regularizer=l2(params['WEIGHT_DECAY']),
            attention_recurrent_regularizer=l2(params['WEIGHT_DECAY']),
            attention_context_regularizer=l2(params['WEIGHT_DECAY']),
            bias_ba_regularizer=l2(params['WEIGHT_DECAY']),
            dropout=params['RECURRENT_INPUT_DROPOUT_P'],
            recurrent_dropout=params['RECURRENT_DROPOUT_P'],
            conditional_dropout=params['RECURRENT_INPUT_DROPOUT_P'],
            attention_dropout=params['DROPOUT_P'],
            kernel_initializer=params['INIT_FUNCTION'],
            recurrent_initializer=params['INNER_INIT'],
            attention_context_initializer=params['INIT_ATT'],
            return_sequences=True,
            return_extra_variables=True,
            return_states=True,
            num_inputs=len(input_attentional_decoder),
            name='decoder_Att' + params['DECODER_RNN_TYPE'] + 'Cond')

        # Outputs: projected hidden state, attended context, attention weights, state (and memory for LSTM).
        rnn_output = sharedAttRNNCond(input_attentional_decoder)
        proj_h = rnn_output[0]
        x_att = rnn_output[1]
        alphas = rnn_output[2]
        h_state = rnn_output[3]
        if use_lstm_decoder:
            h_memory = rnn_output[4]
        shared_Lambda_Permute = PermuteGeneral((1, 0, 2))

        if params['DOUBLE_STOCHASTIC_ATTENTION_REG'] > 0:
            alpha_regularizer = AlphaRegularizer(alpha_factor=params['DOUBLE_STOCHASTIC_ATTENTION_REG'])(alphas)

        # shared_reg_proj_h holds the first decoder layer's regularization layers; the sampling model
        # re-applies exactly these to its first-layer output, so this name must not be rebound below.
        [proj_h, shared_reg_proj_h] = Regularize(proj_h, params, shared_layers=True, name='proj_h0')

        # 3.4. Possibly deep decoder
        shared_proj_h_list = []
        shared_reg_proj_h_list = []

        h_states_list = [h_state]
        if use_lstm_decoder:
            h_memories_list = [h_memory]

        for n_layer in range(1, params['N_LAYERS_DECODER']):
            current_rnn_input = [proj_h, shared_Lambda_Permute(x_att), initial_state]
            if use_lstm_decoder:
                # Append before building the layer so num_inputs counts every input, as for the first layer.
                current_rnn_input.append(initial_memory)
            shared_proj_h_list.append(eval(params['DECODER_RNN_TYPE'].replace('Conditional', '') + 'Cond')(
                params['DECODER_HIDDEN_SIZE'],
                kernel_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                recurrent_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                conditional_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                bias_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                # Same dropout assignment as the encoder and the attentional decoder layer.
                dropout=params['RECURRENT_INPUT_DROPOUT_P'],
                recurrent_dropout=params['RECURRENT_DROPOUT_P'],
                conditional_dropout=params['RECURRENT_INPUT_DROPOUT_P'],
                kernel_initializer=params['INIT_FUNCTION'],
                recurrent_initializer=params['INNER_INIT'],
                return_sequences=True,
                return_states=True,
                num_inputs=len(current_rnn_input),
                name='decoder_' + params['DECODER_RNN_TYPE'].replace('Conditional', '') + 'Cond' + str(n_layer)))

            current_rnn_output = shared_proj_h_list[-1](current_rnn_input)
            current_proj_h = current_rnn_output[0]
            h_states_list.append(current_rnn_output[1])
            if use_lstm_decoder:
                h_memories_list.append(current_rnn_output[2])
            [current_proj_h, current_shared_reg_proj_h] = Regularize(current_proj_h, params, shared_layers=True,
                                                                     name='proj_h' + str(n_layer))
            shared_reg_proj_h_list.append(current_shared_reg_proj_h)

            proj_h = Add()([proj_h, current_proj_h])

        # 3.5. Skip connections between encoder and output layer
        shared_FC_mlp = TimeDistributed(Dense(params['SKIP_VECTORS_HIDDEN_SIZE'],
                                              kernel_initializer=params['INIT_FUNCTION'],
                                              kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                              bias_regularizer=l2(params['WEIGHT_DECAY']),
                                              activation='linear'),
                                        name='logit_lstm')
        out_layer_mlp = shared_FC_mlp(proj_h)
        shared_FC_ctx = TimeDistributed(Dense(params['SKIP_VECTORS_HIDDEN_SIZE'],
                                              kernel_initializer=params['INIT_FUNCTION'],
                                              kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                              bias_regularizer=l2(params['WEIGHT_DECAY']),
                                              activation='linear'),
                                        name='logit_ctx')
        out_layer_ctx = shared_FC_ctx(x_att)
        out_layer_ctx = shared_Lambda_Permute(out_layer_ctx)
        shared_FC_emb = TimeDistributed(Dense(params['SKIP_VECTORS_HIDDEN_SIZE'],
                                              kernel_initializer=params['INIT_FUNCTION'],
                                              kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                              bias_regularizer=l2(params['WEIGHT_DECAY']),
                                              activation='linear'),
                                        name='logit_emb')
        out_layer_emb = shared_FC_emb(state_below)

        [out_layer_mlp, shared_reg_out_layer_mlp] = Regularize(out_layer_mlp, params,
                                                               shared_layers=True, name='out_layer_mlp')
        [out_layer_ctx, shared_reg_out_layer_ctx] = Regularize(out_layer_ctx, params,
                                                               shared_layers=True, name='out_layer_ctx')
        [out_layer_emb, shared_reg_out_layer_emb] = Regularize(out_layer_emb, params,
                                                               shared_layers=True, name='out_layer_emb')

        shared_additional_output_merge = eval(params['ADDITIONAL_OUTPUT_MERGE_MODE'])(name='additional_input')
        additional_output = shared_additional_output_merge([out_layer_mlp, out_layer_ctx, out_layer_emb])
        shared_activation_tanh = Activation('tanh')

        out_layer = shared_activation_tanh(additional_output)

        shared_deep_list = []
        shared_reg_deep_list = []
        # 3.6 Optional deep output layer
        for i, (activation, dimension) in enumerate(params['DEEP_OUTPUT_LAYERS']):
            shared_deep_list.append(TimeDistributed(Dense(dimension, activation=activation,
                                                          kernel_initializer=params['INIT_FUNCTION'],
                                                          kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                                          bias_regularizer=l2(params['WEIGHT_DECAY']),
                                                          ),
                                                    name=activation + '_%d' % i))
            out_layer = shared_deep_list[-1](out_layer)
            [out_layer, shared_reg_out_layer] = Regularize(out_layer,
                                                           params, shared_layers=True,
                                                           name='out_layer_' + str(activation) + '_%d' % i)
            shared_reg_deep_list.append(shared_reg_out_layer)

        # 3.7. Output layer: Softmax
        shared_FC_soft = TimeDistributed(Dense(params['OUTPUT_VOCABULARY_SIZE'],
                                               activation=params['CLASSIFIER_ACTIVATION'],
                                               kernel_regularizer=l2(params['WEIGHT_DECAY']),
                                               bias_regularizer=l2(params['WEIGHT_DECAY']),
                                               name=params['CLASSIFIER_ACTIVATION']
                                               ),
                                         name=self.ids_outputs[0])
        softout = shared_FC_soft(out_layer)

        self.model = Model(inputs=[src_text, next_words], outputs=softout)
        if params['DOUBLE_STOCHASTIC_ATTENTION_REG'] > 0.:
            self.model.add_loss(alpha_regularizer)

        ##################################################################
        #                         SAMPLING MODEL                         #
        ##################################################################
        # Now that we have the basic training model ready, let's prepare the model for applying decoding
        # The beam-search model will include all the minimum required set of layers (decoder stage) which offer the
        # possibility to generate the next state in the sequence given a pre-processed input (encoder stage)
        # First, we need a model that outputs the preprocessed input + initial h state
        # for applying the initial forward pass
        model_init_input = [src_text, next_words]
        model_init_output = [softout, annotations] + h_states_list
        if use_lstm_decoder:
            model_init_output += h_memories_list
        if self.return_alphas:
            model_init_output.append(alphas)
        self.model_init = Model(inputs=model_init_input, outputs=model_init_output)

        # Store inputs and outputs names for model_init
        self.ids_inputs_init = self.ids_inputs
        ids_states_names = ['next_state_' + str(i) for i in range(len(h_states_list))]

        # first output must be the output probs.
        self.ids_outputs_init = self.ids_outputs + ['preprocessed_input'] + ids_states_names
        if use_lstm_decoder:
            ids_memories_names = ['next_memory_' + str(i) for i in range(len(h_memories_list))]
            self.ids_outputs_init += ids_memories_names

        # Second, we need to build an additional model with the capability to have the following inputs:
        #   - preprocessed_input
        #   - prev_word
        #   - prev_state
        # and the following outputs:
        #   - softmax probabilities
        #   - next_state
        preprocessed_size = params['ENCODER_HIDDEN_SIZE'] * 2 if \
            params['BIDIRECTIONAL_ENCODER'] \
            else params['ENCODER_HIDDEN_SIZE']

        # Define inputs
        n_deep_decoder_layer_idx = 0
        preprocessed_annotations = Input(name='preprocessed_input', shape=tuple([None, preprocessed_size]))
        prev_h_states_list = [Input(name='prev_state_' + str(i),
                                    shape=tuple([params['DECODER_HIDDEN_SIZE']]))
                              for i in range(len(h_states_list))]

        input_attentional_decoder = [state_below, preprocessed_annotations,
                                     prev_h_states_list[n_deep_decoder_layer_idx]]

        if use_lstm_decoder:
            prev_h_memories_list = [Input(name='prev_memory_' + str(i),
                                          shape=tuple([params['DECODER_HIDDEN_SIZE']]))
                                    for i in range(len(h_memories_list))]

            input_attentional_decoder.append(prev_h_memories_list[n_deep_decoder_layer_idx])

        # Apply decoder (the same shared layer objects as in the training model)
        rnn_output = sharedAttRNNCond(input_attentional_decoder)
        proj_h = rnn_output[0]
        x_att = rnn_output[1]
        alphas = rnn_output[2]
        h_states_list = [rnn_output[3]]
        if use_lstm_decoder:
            h_memories_list = [rnn_output[4]]

        for reg in shared_reg_proj_h:
            proj_h = reg(proj_h)

        for (rnn_decoder_layer, proj_h_reg) in zip(shared_proj_h_list, shared_reg_proj_h_list):
            n_deep_decoder_layer_idx += 1
            input_rnn_decoder_layer = [proj_h, shared_Lambda_Permute(x_att),
                                       prev_h_states_list[n_deep_decoder_layer_idx]]
            if use_lstm_decoder:
                input_rnn_decoder_layer.append(prev_h_memories_list[n_deep_decoder_layer_idx])

            current_rnn_output = rnn_decoder_layer(input_rnn_decoder_layer)
            current_proj_h = current_rnn_output[0]
            h_states_list.append(current_rnn_output[1])  # h_state
            if use_lstm_decoder:
                h_memories_list.append(current_rnn_output[2])  # h_memory
            for reg in proj_h_reg:
                current_proj_h = reg(current_proj_h)
            proj_h = Add()([proj_h, current_proj_h])

        out_layer_mlp = shared_FC_mlp(proj_h)
        out_layer_ctx = shared_FC_ctx(x_att)
        out_layer_ctx = shared_Lambda_Permute(out_layer_ctx)
        out_layer_emb = shared_FC_emb(state_below)

        for (reg_out_layer_mlp, reg_out_layer_ctx, reg_out_layer_emb) in zip(shared_reg_out_layer_mlp,
                                                                              shared_reg_out_layer_ctx,
                                                                              shared_reg_out_layer_emb):
            out_layer_mlp = reg_out_layer_mlp(out_layer_mlp)
            out_layer_ctx = reg_out_layer_ctx(out_layer_ctx)
            out_layer_emb = reg_out_layer_emb(out_layer_emb)

        additional_output = shared_additional_output_merge([out_layer_mlp, out_layer_ctx, out_layer_emb])
        out_layer = shared_activation_tanh(additional_output)

        for (deep_out_layer, reg_list) in zip(shared_deep_list, shared_reg_deep_list):
            out_layer = deep_out_layer(out_layer)
            for reg in reg_list:
                out_layer = reg(out_layer)

        # Softmax
        softout = shared_FC_soft(out_layer)
        model_next_inputs = [next_words, preprocessed_annotations] + prev_h_states_list
        model_next_outputs = [softout, preprocessed_annotations] + h_states_list
        if use_lstm_decoder:
            model_next_inputs += prev_h_memories_list
            model_next_outputs += h_memories_list

        if self.return_alphas:
            model_next_outputs.append(alphas)

        self.model_next = Model(inputs=model_next_inputs,
                                outputs=model_next_outputs)

        # Store inputs and outputs names for model_next
        # first input must be previous word
        self.ids_inputs_next = [self.ids_inputs[1]] + ['preprocessed_input']
        # first output must be the output probs.
        self.ids_outputs_next = self.ids_outputs + ['preprocessed_input']
        # Input -> Output matchings from model_init to model_next and from model_next to model_next
        self.matchings_init_to_next = {'preprocessed_input': 'preprocessed_input'}
        self.matchings_next_to_next = {'preprocessed_input': 'preprocessed_input'}
        # append all next states and matchings
        for n_state in range(len(prev_h_states_list)):
            self.ids_inputs_next.append('prev_state_' + str(n_state))
            self.ids_outputs_next.append('next_state_' + str(n_state))
            self.matchings_init_to_next['next_state_' + str(n_state)] = 'prev_state_' + str(n_state)
            self.matchings_next_to_next['next_state_' + str(n_state)] = 'prev_state_' + str(n_state)

        if use_lstm_decoder:
            for n_memory in range(len(prev_h_memories_list)):
                self.ids_inputs_next.append('prev_memory_' + str(n_memory))
                self.ids_outputs_next.append('next_memory_' + str(n_memory))
                self.matchings_init_to_next['next_memory_' + str(n_memory)] = 'prev_memory_' + str(n_memory)
                self.matchings_next_to_next['next_memory_' + str(n_memory)] = 'prev_memory_' + str(n_memory)