diff --git a/run_summarization.py b/run_summarization.py index b529719..6617d34 100644 --- a/run_summarization.py +++ b/run_summarization.py @@ -294,9 +294,9 @@ def main(unused_argv): # Make a namedtuple hps, containing the values of the hyperparameters that the model needs hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen'] hps_dict = {} - for key,val in FLAGS.__flags.iteritems(): # for each flag - if key in hparam_list: # if it's in the list - hps_dict[key] = val # add it to the dict + for key in hparam_list: # for each hyperparameter + if hasattr(FLAGS, key): # if it was given on the command line + hps_dict[key] = getattr(FLAGS, key) # add it to the dict hps = namedtuple("HParams", hps_dict.keys())(**hps_dict) # Create a batcher object that will create minibatches of data