diff --git a/README.md b/README.md index 6549739..bcc5848 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,23 @@ This will create a subdirectory of your specified `log_root` called `myexperimen **Increasing sequence length during training**: Note that to obtain the results described in the paper, we increase the values of `max_enc_steps` and `max_dec_steps` in stages throughout training (mostly so we can perform quicker iterations during early stages of training). If you wish to do the same, start with small values of `max_enc_steps` and `max_dec_steps`, then interrupt and restart the job with larger values when you want to increase them. +#### Run training with flip +By default training is done with "teacher forcing", +instead of generating a new word and then feeding in that word as input when +generating the next word, the expected word in the actual headline is fed in. + +However, during decoding the previously generated word is fed in when +generating the next word. That leads to a disconnect between training +and testing. To overcome this disconnect, during training +you can set a random fraction of the steps to be replaced with +the predicted word of the previous step. You can do this with `--flip=` + +You can increase `flip` in a scheduled way. First train without any flip and then +increade flip to `0.2` +(https://arxiv.org/abs/1506.03099) + +For debugging, if you want to see what are all the predicted words for all steps, run with `--mode=flip` + ### Run (concurrent) eval You may want to run a concurrent evaluation job, that runs your model on the validation set and logs the loss. To do this, run: @@ -53,6 +70,14 @@ Additionally, the decode job produces a file called `attn_vis_data.json`. This f If you want to run evaluation on the entire validation or test set and get ROUGE scores, set the flag `single_pass=1`. This will go through the entire dataset in order, writing the generated summaries to file, and then run evaluation using [pyrouge](https://pypi.python.org/pypi/pyrouge). (Note this will *not* produce the `attn_vis_data.json` files for the attention visualizer). +By default the beamsearch algorithm takes the best `--topk` results but instead you can +specificy that the `topk` result are randomly selected using `--temperature` parameter. (e.g. `0.8`) + +You can request multiple results to be generated for each article with `--ntrials`. +You can force the different trials to be different using `--dbs_lambda` (e.g. `11`) +to add +a penality for having a beam with same token as another beam. (https://arxiv.org/pdf/1610.02424.pdf) + ### Evaluate with ROUGE `decode.py` uses the Python package [`pyrouge`](https://pypi.python.org/pypi/pyrouge) to run ROUGE evaluation. `pyrouge` provides an easier-to-use interface for the official Perl ROUGE package, which you must install for `pyrouge` to work. Here are some useful instructions on how to do this: * [How to setup Perl ROUGE](http://kavita-ganesan.com/rouge-howto) diff --git a/beam_search.py b/beam_search.py index ff3e328..944ab22 100644 --- a/beam_search.py +++ b/beam_search.py @@ -19,6 +19,7 @@ import tensorflow as tf import numpy as np import data +import copy FLAGS = tf.app.flags.FLAGS @@ -78,7 +79,7 @@ def avg_log_prob(self): return self.log_prob / len(self.tokens) -def run_beam_search(sess, model, vocab, batch): +def run_beam_search(sess, model, vocab, batch, previous_best_hyps): """Performs beam search decoding on the given example. Args: @@ -125,20 +126,42 @@ def run_beam_search(sess, model, vocab, batch): num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct. for i in xrange(num_orig_hyps): h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info - for j in xrange(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps: + n = len(h.tokens) + for j in xrange(FLAGS.topk): # for each of the top 2*beam_size hyps: + token = topk_ids[i, j] + score = topk_log_probs[i, j] + if FLAGS.dbs_lambda: + if token in [p.tokens[n] for p in previous_best_hyps if n < len(p.tokens)]: + score -= FLAGS.dbs_lambda # Extend the ith hypothesis with the jth option - new_hyp = h.extend(token=topk_ids[i, j], - log_prob=topk_log_probs[i, j], + new_hyp = h.extend(token=token, + log_prob=score, state=new_state, attn_dist=attn_dist, p_gen=p_gen, coverage=new_coverage_i) all_hyps.append(new_hyp) + temperature = model._hps.temperature + if temperature is None or temperature <= 0.: + all_hyps = sort_hyps(all_hyps) + else: + n = min(FLAGS.beam_size*2, len(all_hyps)) + prb = np.exp(np.array([h.avg_log_prob for h in all_hyps]) / temperature) + res = [] + for i in xrange(n): + z = np.sum(prb) + r = np.argmax(np.random.multinomial(1, prb / z, 1)) + res.append(all_hyps[r]) + prb[r] = 0. # make sure we select each element only once + all_hyps = res + # Filter and collect any hypotheses that have produced the end token. hyps = [] # will contain hypotheses for the next step - for h in sort_hyps(all_hyps): # in order of most likely h - if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached... + for h in all_hyps: # in order of most likely h + if h.latest_token == vocab.word2id(data.UNKNOWN_TOKEN): # skip UNKOWN + continue + elif h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached... # If this hypothesis is sufficiently long, put in results. Otherwise discard. if steps >= FLAGS.min_dec_steps: results.append(h) @@ -147,13 +170,17 @@ def run_beam_search(sess, model, vocab, batch): if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size: # Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop. break + if len(hyps) == 0: + break + while len(hyps) < FLAGS.beam_size: + hyps.append(copy.copy(hyps[-1])) steps += 1 # At this point, either we've got beam_size results, or we've reached maximum decoder steps - if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results - results = hyps + if len(results)==0: # we don't have any complete result + return None # Sort hypotheses by average log probability hyps_sorted = sort_hyps(results) diff --git a/decode.py b/decode.py index 90b5aec..789ca23 100644 --- a/decode.py +++ b/decode.py @@ -95,28 +95,40 @@ def decode(self): article_withunks = data.show_art_oovs(original_article, self._vocab) # string abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string - # Run beam search to get best Hypothesis - best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch) - - # Extract the output ids from the hypothesis and convert back to words - output_ids = [int(t) for t in best_hyp.tokens[1:]] - decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) - - # Remove the [STOP] token from decoded_words, if necessary - try: - fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol - decoded_words = decoded_words[:fst_stop_idx] - except ValueError: - decoded_words = decoded_words - decoded_output = ' '.join(decoded_words) # single string - - if FLAGS.single_pass: - self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later - counter += 1 # this is how many examples we've decoded - else: - print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen - self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool - + tf.logging.info('ARTICLE: %s', article_withunks) + tf.logging.info('REFERENCE SUMMARY: %s', abstract_withunks) + + all_best_hyp = [] + for trial in range(int(FLAGS.ntrials)): + # Run beam search to get best Hypothesis + best_hyp = None + while best_hyp is None: + best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, all_best_hyp) + all_best_hyp.append(best_hyp) + + # Extract the output ids from the hypothesis and convert back to words + output_ids = [int(t) for t in best_hyp.tokens[1:]] + decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) + + # Remove the [STOP] token from decoded_words, if necessary + try: + fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol + decoded_words = decoded_words[:fst_stop_idx] + except ValueError: + decoded_words = decoded_words + decoded_output = ' '.join(decoded_words) # single string + + if FLAGS.single_pass: + self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later + counter += 1 # this is how many examples we've decoded + break + else: + tf.logging.info('GENERATED SUMMARY: %.2f %s', np.mean(best_hyp.log_probs), decoded_output) + if int(FLAGS.ntrials) == 1: + # print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen + self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool + + if not FLAGS.single_pass: # Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint t1 = time.time() if t1-t0 > SECS_UNTIL_NEW_CKPT: diff --git a/model.py b/model.py index 660848d..5b7c632 100644 --- a/model.py +++ b/model.py @@ -22,7 +22,7 @@ import tensorflow as tf from attention_decoder import attention_decoder from tensorflow.contrib.tensorboard.plugins import projector - +EPS = 1e-8 FLAGS = tf.app.flags.FLAGS class SummarizationModel(object): @@ -240,12 +240,12 @@ def _add_seq2seq(self): if FLAGS.pointer_gen: final_dists = self._calc_final_dist(vocab_dists, self.attn_dists) # Take log of final distribution - log_dists = [tf.log(dist) for dist in final_dists] + log_dists = [tf.log(dist + EPS) for dist in final_dists] else: # just take log of vocab_dists - log_dists = [tf.log(dist) for dist in vocab_dists] + log_dists = [tf.log(dist + EPS) for dist in vocab_dists] - if hps.mode in ['train', 'eval']: + if hps.mode in ['train', 'eval', 'flip']: # Calculate the loss with tf.variable_scope('loss'): if FLAGS.pointer_gen: # calculate loss from log_dists @@ -279,8 +279,12 @@ def _add_seq2seq(self): # We run decode beam search mode one decoder step at a time assert len(log_dists)==1 # log_dists is a singleton list containing shape (batch_size, extended_vsize) log_dists = log_dists[0] - self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.batch_size*2) # note batch_size=beam_size in decode mode - + self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.topk) # note batch_size=beam_size in decode mode + if hps.mode == "flip" or (hps.mode == "train" and hps.flip): + # for flipping use only decoder output without pointer + if FLAGS.pointer_gen: + log_dists = [tf.log(dist + EPS) for dist in vocab_dists] + self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.topk) # note batch_size=beam_size in decode mode def _add_train_op(self): """Sets self._train_op, the op to run for training.""" @@ -290,15 +294,23 @@ def _add_train_op(self): gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE) # Clip the gradients - with tf.device("/gpu:0"): + with tf.device("/%s"%self._hps.gpu): grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm) # Add a summary tf.summary.scalar('global_norm', global_norm) # Apply adagrad optimizer - optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc) - with tf.device("/gpu:0"): + if self._hps.optimizer == 'adagrad': + optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc) + elif self._hps.optimizer == 'adam': + optimizer = tf.train.AdamOptimizer(self._hps.lr) + elif self._hps.optimizer == 'yellowfin': + from yellowfin import YFOptimizer + optimizer = YFOptimizer(lr_factor=self._hps.lr) + elif self._hps.optimizer == 'sgd': + optimizer = tf.train.MomentumOptimizer(self._hps.lr, 0.9) + with tf.device("/%s"%self._hps.gpu): self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step') @@ -307,7 +319,7 @@ def build_graph(self): tf.logging.info('Building graph...') t0 = time.time() self._add_placeholders() - with tf.device("/gpu:0"): + with tf.device("/%s"%self._hps.gpu): self._add_seq2seq() self.global_step = tf.Variable(0, name='global_step', trainable=False) if self._hps.mode == 'train': @@ -341,6 +353,17 @@ def run_eval_step(self, sess, batch): to_return['coverage_loss'] = self._coverage_loss return sess.run(to_return, feed_dict) + def run_decode(self, sess, batch): + feed_dict = self._make_feed_dict(batch) + to_return = { + "ids": self._topk_ids, + "probs": self._topk_log_probs, + 'summaries': self._summaries, + 'loss': self._loss, + 'global_step': self.global_step, + } + return sess.run(to_return, feed_dict) + def run_encoder(self, sess, batch): """For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state. @@ -381,6 +404,7 @@ def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode. new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on. """ + hps = self._hps beam_size = len(dec_init_states) diff --git a/run_summarization.py b/run_summarization.py index 8639af9..1286ff8 100644 --- a/run_summarization.py +++ b/run_summarization.py @@ -27,6 +27,9 @@ from model import SummarizationModel from decode import BeamSearchDecoder import util +import data +import random +from tensorflow.python.ops import variables FLAGS = tf.app.flags.FLAGS @@ -41,6 +44,7 @@ # Where to save output tf.app.flags.DEFINE_string('log_root', '', 'Root directory for all logging.') tf.app.flags.DEFINE_string('exp_name', '', 'Name for experiment. Logs will be saved in a directory with this name, under log_root.') +tf.app.flags.DEFINE_string('pre_path', None, 'Full path to a previous experiment to start from it can have a different optimization method. e.g. //train/model.ckpt-') # Hyperparameters tf.app.flags.DEFINE_integer('hidden_dim', 256, 'dimension of RNN hidden states') @@ -64,6 +68,13 @@ tf.app.flags.DEFINE_boolean('coverage', False, 'Use coverage mechanism. Note, the experiments reported in the ACL paper train WITHOUT coverage until converged, and then train for a short phase WITH coverage afterwards. i.e. to reproduce the results in the ACL paper, turn this off for most of training then turn on for a short phase at the end.') tf.app.flags.DEFINE_float('cov_loss_wt', 1.0, 'Weight of coverage loss (lambda in the paper). If zero, then no incentive to minimize coverage loss.') tf.app.flags.DEFINE_boolean('convert_to_coverage_model', False, 'Convert a non-coverage model to a coverage model. Turn this on and run in train mode. Your current model will be copied to a new version (same name with _cov_init appended) that will be ready to run with coverage flag turned on, for the coverage training stage.') +tf.app.flags.DEFINE_float('temperature', None, 'When decoding, Beam search temperature. If None take top result otherwise randomly draw from topk=100 results. Try 0.1') +tf.app.flags.DEFINE_float('ntrials', 1, 'How many decoding to perform') +tf.app.flags.DEFINE_float('topk', None, 'When decoding, How many results to give from the model') +tf.app.flags.DEFINE_float('dbs_lambda', None, 'When ntrials>1, Penality for having a beam with same token as another beam. Try 2') +tf.app.flags.DEFINE_float('flip', None, 'When training, what part of the decoder input should be flipped with decoder output from previous step') +tf.app.flags.DEFINE_string('optimizer', 'adagrad', 'Which optimization method to use: adagrad (default), adam (try lr=2e-4), yellowfin (lr is YF\'s lr_decay, try lr=1), sgd (with momentum set to 0.9)') +tf.app.flags.DEFINE_string('gpu', 'gpu:0', 'Which GPU to use') def calc_running_avg_loss(loss, running_avg_loss, summary_writer, step, decay=0.99): @@ -122,13 +133,26 @@ def setup_training(model, batcher): train_dir = os.path.join(FLAGS.log_root, "train") if not os.path.exists(train_dir): os.makedirs(train_dir) - default_device = tf.device('/cpu:0') + default_device = tf.device('/%s'%FLAGS.gpu) with default_device: model.build_graph() # build the graph if FLAGS.convert_to_coverage_model: assert FLAGS.coverage, "To convert your non-coverage model to a coverage model, run with convert_to_coverage_model=True and coverage=True" convert_to_coverage_model() saver = tf.train.Saver(max_to_keep=1) # only keep 1 checkpoint at a time + if FLAGS.pre_path is not None: + # https://www.tensorflow.org/programmers_guide/supervisor + # remove variables that belong to optimization + var_list = variables._all_saveable_objects() + var_list = [v for v in var_list + if not any(v.op.name.endswith(s) + for s in ['Momentum', + 'YF_lr', 'YF_lr_factor', 'YF_mu', "YF_clip_thresh"])] # eg "seq2seq/embedding/embedding/Momentum" + pre_train_saver = tf.train.Saver(var_list) + def load_pretrain(sess): + pre_train_saver.restore(sess, FLAGS.pre_path) + else: + load_pretrain = None sv = tf.train.Supervisor(logdir=train_dir, is_chief=True, @@ -136,7 +160,8 @@ def setup_training(model, batcher): summary_op=None, save_summaries_secs=60, # save summaries for tensorboard every 60 secs save_model_secs=60, # checkpoint every 60 secs - global_step=model.global_step) + global_step=model.global_step, + init_fn=load_pretrain) summary_writer = sv.summary_writer tf.logging.info("Preparing or waiting for session...") sess_context_manager = sv.prepare_or_wait_for_session(config=util.get_config()) @@ -155,6 +180,33 @@ def run_training(model, batcher, sess_context_manager, sv, summary_writer): while True: # repeats until interrupted batch = batcher.next_batch() + if FLAGS.flip: + t0 = time.time() + results = model.run_decode(sess, batch) + t1 = time.time() + tf.logging.info('seconds for batch flip: %.2f', t1 - t0) + stop_id = model._vocab.word2id(data.STOP_DECODING) + start_id = model._vocab.word2id(data.START_DECODING) + pad_id = model._vocab.word2id(data.PAD_TOKEN) + + for b in range(len(batch.dec_batch)): + fst_stop_idx = np.where(batch.target_batch[b, :] == stop_id)[0] + if len(fst_stop_idx): + fst_stop_idx = fst_stop_idx[0] + else: + fst_stop_idx = len(batch.target_batch[b, :])-1 + + output_ids = [int(x) for x in results['ids'][:fst_stop_idx, b, 0]] + nflips = int(fst_stop_idx * FLAGS.flip + 0.5) + + flips = sorted(random.sample(xrange(fst_stop_idx), nflips)) + for input_idx in flips: + if output_ids[input_idx] in [stop_id, start_id, pad_id]: + continue + if batch.dec_batch[b, input_idx+1] in [stop_id, start_id, pad_id]: + continue + batch.dec_batch[b, input_idx+1] = output_ids[input_idx] + tf.logging.info('running training step...') t0=time.time() results = model.run_train_step(sess, batch) @@ -175,6 +227,45 @@ def run_training(model, batcher, sess_context_manager, sv, summary_writer): if train_step % 100 == 0: # flush the summary writer every so often summary_writer.flush() +def run_flip(model, batcher, vocab): + """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" + model.build_graph() # build the graph + saver = tf.train.Saver(max_to_keep=3) # we will keep 3 best checkpoints at a time + sess = tf.Session(config=util.get_config()) + + while True: + _ = util.load_ckpt(saver, sess) # load a new checkpoint + batch = batcher.next_batch() # get the next batch + + # run eval on the batch + t0=time.time() + # results = model.run_eval_step(sess, batch) + results = model.run_decode(sess, batch) + t1=time.time() + tf.logging.info('seconds for batch: %.2f', t1-t0) + # output_ids = [int(t) for t in best_hyp.tokens[1:]] + # decoded_words = data.outputids2words(output_ids, self._vocab, ( + # batch.art_oovs[0] if FLAGS.pointer_gen else None)) + stop_id = model._vocab.word2id('[STOP]') + + for b in range(len(batch.original_abstracts)): + original_abstract = batch.original_abstracts[b] # string + fst_stop_idx = np.where(batch.target_batch[b,:] == stop_id)[0] + if len(fst_stop_idx): + fst_stop_idx = fst_stop_idx[0] + else: + fst_stop_idx = len(batch.target_batch[b,:]) + + abstract_withunks = data.show_abs_oovs(original_abstract, model._vocab, None) # string + tf.logging.info('REFERENCE SUMMARY: %s', abstract_withunks) + + output_ids = [int(x) for x in results['ids'][:, b, 0]] + output_ids = output_ids[:fst_stop_idx] + decoded_words = data.outputids2words(output_ids, model._vocab, None) + decoded_output = ' '.join(decoded_words) + tf.logging.info('GENERATED SUMMARY: %s', decoded_output) + loss = results['loss'] + tf.logging.info('loss %.2f max flip %d dec %d target %d', loss, max(output_ids), batch.dec_batch.max(), batch.target_batch.max()) def run_eval(model, batcher, vocab): """Repeatedly runs eval iterations, logging to screen and writing summaries. Saves the model with the best loss seen so far.""" @@ -246,13 +337,21 @@ def main(unused_argv): # On each step, we have beam_size-many hypotheses in the beam, so we need to make a batch of these hypotheses. if FLAGS.mode == 'decode': FLAGS.batch_size = FLAGS.beam_size + if FLAGS.topk is None and FLAGS.temperature is not None: + FLAGS.topk = 100 # TODO this is a hacky and slow solution to the problem + else: + FLAGS.topk = FLAGS.batch_size*2 + elif FLAGS.mode == 'flip' or FLAGS.flip: + FLAGS.topk = 1 + elif FLAGS.topk is not None: + FLAGS.topk = int(FLAGS.topk) # If single_pass=True, check we're in decode mode if FLAGS.single_pass and FLAGS.mode!='decode': raise Exception("The single_pass flag should only be True in decode mode") # Make a namedtuple hps, containing the values of the hyperparameters that the model needs - hparam_list = ['mode', 'lr', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen'] + hparam_list = ['gpu', 'mode', 'lr', 'optimizer', 'adagrad_init_acc', 'rand_unif_init_mag', 'trunc_norm_init_std', 'max_grad_norm', 'hidden_dim', 'emb_dim', 'batch_size', 'max_dec_steps', 'max_enc_steps', 'coverage', 'cov_loss_wt', 'pointer_gen', 'temperature', 'topk', 'flip'] hps_dict = {} for key,val in FLAGS.__flags.iteritems(): # for each flag if key in hparam_list: # if it's in the list @@ -271,6 +370,9 @@ def main(unused_argv): elif hps.mode == 'eval': model = SummarizationModel(hps, vocab) run_eval(model, batcher, vocab) + elif hps.mode == 'flip': + model = SummarizationModel(hps, vocab) + run_flip(model, batcher, vocab) elif hps.mode == 'decode': decode_model_hps = hps # This will be the hyperparameters for the decoder model decode_model_hps = hps._replace(max_dec_steps=1) # The model is configured with max_dec_steps=1 because we only ever run one step of the decoder at a time (to do beam search). Note that the batcher is initialized with max_dec_steps equal to e.g. 100 because the batches need to contain the full summaries diff --git a/yellowfin.py b/yellowfin.py new file mode 100644 index 0000000..1aec4f3 --- /dev/null +++ b/yellowfin.py @@ -0,0 +1,258 @@ +import numpy as np +from math import ceil, floor +import tensorflow as tf +from tensorflow.python.training import momentum +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.framework import ops + +# Values for gate_gradients. +GATE_NONE = 0 +GATE_OP = 1 +GATE_GRAPH = 2 + +class YFOptimizer(object): + def __init__(self, lr=1.0, mu=0.0, clip_thresh=None, beta=0.999, curv_win_width=20, + mu_update_interval=1, zero_debias=True, delta_mu=0.0, lr_factor=1.): + ''' + clip thresh is the threshold value on ||lr * gradient|| + delta_mu can be place holder/variable/python scalar. They are used for additional + momentum in situations such as asynchronous-parallel training. The default is 0.0 + for basic usage of the optimizer. + Args: + lr: python scalar. The initial value of learning rate, we use 1.0 in our paper. + mu: python scalar. The initial value of momentum, we use 0.0 in our paper. + clip_thresh: python scalar. The cliping threshold for tf.clip_by_global_norm. + if None, no clipping will be carried out. + beta: python scalar. The smoothing parameter for estimations. + delta_mu: for extensions. Not necessary in the basic use. + Other features: + If you want to manually control the learning rates, self.lr_factor is + an interface to the outside, it is an multiplier for the internal learning rate + in YellowFin. It is helpful when you want to do additional hand tuning + or some decaying scheme to the tuned learning rate in YellowFin. + Example on using lr_factor can be found here: + https://github.com/JianGoForIt/YellowFin/blob/master/char-rnn-tensorflow/train_YF.py#L140 + ''' + self._lr = lr + self._mu = mu + + self._lr_var = tf.Variable(lr, dtype=tf.float32, name="YF_lr", trainable=False) + self._mu_var = tf.Variable(mu, dtype=tf.float32, name="YF_mu", trainable=False) + # for step scheme or decaying scheme for the learning rates + self.lr_factor = tf.Variable(lr_factor, dtype=tf.float32, name="YF_lr_factor", trainable=False) + if clip_thresh is not None: + self._clip_thresh_var = tf.Variable(clip_thresh, dtype=tf.float32, name="YF_clip_thresh", trainable=False) + else: + self._clip_thresh_var = None + + # the underlying momentum optimizer + self._optimizer = \ + tf.train.MomentumOptimizer(self._lr_var * self.lr_factor, self._mu_var + delta_mu) + + # moving average for statistics + self._beta = beta + self._moving_averager = None + + # for global step counting + # self._global_step = tf.Variable(0, trainable=False) + + # for conditional tuning + # self._do_tune = tf.greater(self._global_step, tf.constant(0) ) + + self._zero_debias = zero_debias + + self._tvars = None + + # for curvature range + self._curv_win_width = curv_win_width + self._curv_win = None + + + def curvature_range(self): + # set up the curvature window + self._curv_win = \ + tf.Variable(np.zeros( [self._curv_win_width, ] ), dtype=tf.float32, name="curv_win", trainable=False) + self._curv_win = tf.scatter_update(self._curv_win, + self._global_step % self._curv_win_width, self._grad_norm_squared) + # note here the iterations start from iteration 0 + valid_window = tf.slice(self._curv_win, tf.constant( [0, ] ), + tf.expand_dims(tf.minimum(tf.constant(self._curv_win_width), self._global_step + 1), dim=0) ) + self._h_min_t = tf.reduce_min(valid_window) + self._h_max_t = tf.reduce_max(valid_window) + + curv_range_ops = [] + with tf.control_dependencies([self._h_min_t, self._h_max_t] ): + avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t] ) + with tf.control_dependencies([avg_op] ): + self._h_min = tf.identity(self._moving_averager.average(self._h_min_t) ) + self._h_max = tf.identity(self._moving_averager.average(self._h_max_t) ) + curv_range_ops.append(avg_op) + return curv_range_ops + + + def grad_variance(self): + grad_var_ops = [] + tensor_to_avg = [] + for t, g in zip(self._tvars, self._grads): + if isinstance(g, ops.IndexedSlices): + tensor_to_avg.append(tf.reshape(tf.unsorted_segment_sum(g.values, g.indices, g.dense_shape[0] ), shape=t.get_shape() ) ) + else: + tensor_to_avg.append(g) + avg_op = self._moving_averager.apply(tensor_to_avg) + grad_var_ops.append(avg_op) + with tf.control_dependencies([avg_op] ): + self._grad_avg = [self._moving_averager.average(val) for val in tensor_to_avg] + self._grad_avg_squared = [tf.square(val) for val in self._grad_avg] + self._grad_var = self._grad_norm_squared_avg - tf.add_n( [tf.reduce_sum(val) for val in self._grad_avg_squared] ) + return grad_var_ops + + + def dist_to_opt(self): + dist_to_opt_ops = [] + # running average of the norm of gradeint + self._grad_norm = tf.sqrt(self._grad_norm_squared) + avg_op = self._moving_averager.apply([self._grad_norm,] ) + dist_to_opt_ops.append(avg_op) + with tf.control_dependencies([avg_op] ): + self._grad_norm_avg = self._moving_averager.average(self._grad_norm) + # single iteration distance estimation, note here self._grad_norm_avg is per variable + self._dist_to_opt = self._grad_norm_avg / self._grad_norm_squared_avg + # running average of distance + avg_op = self._moving_averager.apply([self._dist_to_opt] ) + dist_to_opt_ops.append(avg_op) + with tf.control_dependencies([avg_op]): + self._dist_to_opt_avg = tf.identity(self._moving_averager.average(self._dist_to_opt) ) + return dist_to_opt_ops + + + def after_apply(self): + self._moving_averager = tf.train.ExponentialMovingAverage(decay=self._beta, zero_debias=self._zero_debias) + assert self._grads != None and len(self._grads) > 0 + after_apply_ops = [] + + # get per var g**2 and norm**2 + self._grad_squared = [] + self._grad_norm_squared = [] + for v, g in zip(self._tvars, self._grads): + with ops.colocate_with(v): + self._grad_squared.append(tf.square(g) ) + self._grad_norm_squared = [tf.reduce_sum(grad_squared) for grad_squared in self._grad_squared] + + # the following running average on squared norm of gradient is shared by grad_var and dist_to_opt + avg_op = self._moving_averager.apply(self._grad_norm_squared) + with tf.control_dependencies([avg_op] ): + self._grad_norm_squared_avg = [self._moving_averager.average(val) for val in self._grad_norm_squared] + self._grad_norm_squared = tf.add_n(self._grad_norm_squared) + self._grad_norm_squared_avg = tf.add_n(self._grad_norm_squared_avg) + after_apply_ops.append(avg_op) + + with tf.control_dependencies([avg_op] ): + curv_range_ops = self.curvature_range() + after_apply_ops += curv_range_ops + grad_var_ops = self.grad_variance() + after_apply_ops += grad_var_ops + dist_to_opt_ops = self.dist_to_opt() + after_apply_ops += dist_to_opt_ops + + return tf.group(*after_apply_ops) + + + def get_lr_tensor(self): + lr = (1.0 - tf.sqrt(self._mu) )**2 / self._h_min + return lr + + + def get_mu_tensor(self): + const_fact = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var + coef = tf.Variable([-1.0, 3.0, 0.0, 1.0], dtype=tf.float32, name="cubic_solver_coef") + coef = tf.scatter_update(coef, tf.constant(2), -(3 + const_fact) ) + roots = tf.py_func(np.roots, [coef], Tout=tf.complex64, stateful=False) + + # filter out the correct root + root_idx = tf.logical_and(tf.logical_and(tf.greater(tf.real(roots), tf.constant(0.0) ), + tf.less(tf.real(roots), tf.constant(1.0) ) ), tf.less(tf.abs(tf.imag(roots) ), 1e-5) ) + # in case there are two duplicated roots satisfying the above condition + root = tf.reshape(tf.gather(tf.gather(roots, tf.where(root_idx) ), tf.constant(0) ), shape=[] ) + tf.assert_equal(tf.size(root), tf.constant(1) ) + + dr = self._h_max / self._h_min + mu = tf.maximum(tf.real(root)**2, ( (tf.sqrt(dr) - 1)/(tf.sqrt(dr) + 1) )**2) + return mu + + + def update_hyper_param(self): + assign_hyper_ops = [] + self._mu = tf.identity(tf.cond(self._do_tune, lambda: self.get_mu_tensor(), + lambda: self._mu_var) ) + with tf.control_dependencies([self._mu] ): + self._lr = tf.identity(tf.cond(self._do_tune, lambda: self.get_lr_tensor(), + lambda: self._lr_var) ) + + with tf.control_dependencies([self._mu, self._lr] ): + self._mu = self._beta * self._mu_var + (1 - self._beta) * self._mu + self._lr = self._beta * self._lr_var + (1 - self._beta) * self._lr + assign_hyper_ops.append(tf.assign(self._mu_var, self._mu) ) + assign_hyper_ops.append(tf.assign(self._lr_var, self._lr) ) + assign_hyper_op = tf.group(*assign_hyper_ops) + return assign_hyper_op + + + def apply_gradients(self, grads_tvars, global_step, name): + self._global_step = global_step + self._do_tune = tf.greater(self._global_step, tf.constant(0)) + self._grads, self._tvars = zip(*grads_tvars) + + with tf.variable_scope("apply_updates"): + if self._clip_thresh_var is not None: + self._grads_clip, self._grads_norm = tf.clip_by_global_norm(self._grads, self._clip_thresh_var) + apply_grad_op = \ + self._optimizer.apply_gradients(zip(self._grads_clip, self._tvars) ) + else: + apply_grad_op = \ + self._optimizer.apply_gradients(zip(self._grads, self._tvars) ) + + + with tf.variable_scope("after_apply"): + after_apply_op = self.after_apply() + + with tf.variable_scope("update_hyper"): + with tf.control_dependencies( [after_apply_op] ): + update_hyper_op = self.update_hyper_param() + + with tf.control_dependencies([update_hyper_op] ): + self._increment_global_step_op = tf.assign(self._global_step, self._global_step + 1) + + return tf.group(apply_grad_op, after_apply_op, update_hyper_op, self._increment_global_step_op) + + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=GATE_OP, aggregation_method=None, + colocate_gradients_with_ops=False, name=None, + grad_loss=None): + """Adapted from Tensorflow Optimizer base class member function: + Add operations to minimize `loss` by updating `var_list`. + This method simply combines calls `compute_gradients()` and + `apply_gradients()`. If you want to process the gradient before applying + them call `tf.gradients()` and `self.apply_gradients()` explicitly instead + of using this function. + """ + grads_and_vars = self._optimizer.compute_gradients( + loss, var_list=var_list, gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + + vars_with_grad = [v for g, v in grads_and_vars if g is not None] + if not vars_with_grad: + raise ValueError( + "No gradients provided for any variable, check your graph for ops" + " that do not support gradients, between variables %s and loss %s." % + ([str(v) for _, v in grads_and_vars], loss)) + for g, v in grads_and_vars: + print "g ", g + print "v ", v + + return self.apply_gradients(grads_and_vars) + + \ No newline at end of file