25 changes: 25 additions & 0 deletions README.md
@@ -27,6 +27,23 @@ This will create a subdirectory of your specified `log_root` called `myexperimen

**Increasing sequence length during training**: Note that to obtain the results described in the paper, we increase the values of `max_enc_steps` and `max_dec_steps` in stages throughout training (mostly so we can perform quicker iterations during early stages of training). If you wish to do the same, start with small values of `max_enc_steps` and `max_dec_steps`, then interrupt and restart the job with larger values when you want to increase them.

#### Run training with flip
By default, training uses "teacher forcing": instead of generating a new word
and then feeding that word back in as input when generating the next word,
the expected word from the actual headline is fed in.

However, during decoding the previously generated word is fed in when
generating the next word, which creates a mismatch between training
and testing. To reduce this mismatch, during training
you can have a random fraction of the steps use the word predicted
at the previous step instead of the ground-truth word. You can do this with `--flip=<fraction>`.

You can increase `flip` on a schedule: first train without any flip, then
increase `flip` to e.g. `0.2`
(see [Scheduled Sampling](https://arxiv.org/abs/1506.03099)).

For debugging, if you want to see the predicted words at every step, run with `--mode=flip`.
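
The sketch below illustrates the idea behind `--flip` (scheduled sampling); the function `flip_decoder_inputs` and its arguments are illustrative, not part of this repository:

```python
import numpy as np

def flip_decoder_inputs(gold_inputs, prev_predictions, flip):
  """Return decoder inputs where a random fraction `flip` of the
  teacher-forced (gold) tokens is replaced by the model's own
  prediction from the previous step.

  gold_inputs: list of token ids from the reference headline
  prev_predictions: list of token ids predicted at the previous step
  flip: float in [0, 1], fraction of steps to replace
  """
  mixed = []
  for gold, pred in zip(gold_inputs, prev_predictions):
    # With probability `flip`, feed the model's own prediction instead of the gold token.
    mixed.append(pred if np.random.rand() < flip else gold)
  return mixed
```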

### Run (concurrent) eval
You may want to run a concurrent evaluation job that runs your model on the validation set and logs the loss. To do this, run:

@@ -53,6 +70,14 @@ Additionally, the decode job produces a file called `attn_vis_data.json`. This f

If you want to run evaluation on the entire validation or test set and get ROUGE scores, set the flag `single_pass=1`. This will go through the entire dataset in order, writing the generated summaries to file, and then run evaluation using [pyrouge](https://pypi.python.org/pypi/pyrouge). (Note this will *not* produce the `attn_vis_data.json` files for the attention visualizer).

By default the beam search algorithm keeps the best `--topk` results. Alternatively, you can
have the `topk` results sampled at random, weighted by their scores, by setting the `--temperature` parameter (e.g. `0.8`).
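
A minimal sketch of this temperature-based sampling, assuming a list of hypotheses with average log-probability scores (names here are illustrative):

```python
import numpy as np

def sample_hypotheses(hyps, avg_log_probs, temperature, k):
  """Sample k hypotheses without replacement, with probability
  proportional to exp(avg_log_prob / temperature)."""
  probs = np.exp(np.array(avg_log_probs) / temperature)
  chosen = []
  for _ in range(min(k, len(hyps))):
    idx = np.argmax(np.random.multinomial(1, probs / probs.sum()))
    chosen.append(hyps[idx])
    probs[idx] = 0.  # never pick the same hypothesis twice
  return chosen
```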

You can request multiple results to be generated for each article with `--ntrials`.
You can force the different trials to differ from each other using `--dbs_lambda` (e.g. `11`), which adds
a penalty whenever a beam produces the same token as a previous beam
(see [Diverse Beam Search](https://arxiv.org/pdf/1610.02424.pdf)).
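
Roughly, the penalty lowers a candidate token's score when an earlier trial already used that token at the same decoding step (a simplified sketch, not the exact code):

```python
def penalize_repeated_token(score, token, step, previous_hyps, dbs_lambda):
  """Subtract dbs_lambda from a candidate's log probability if any
  previously decoded hypothesis used the same token at this step."""
  for prev in previous_hyps:
    if step < len(prev.tokens) and prev.tokens[step] == token:
      return score - dbs_lambda
  return score
```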

### Evaluate with ROUGE
`decode.py` uses the Python package [`pyrouge`](https://pypi.python.org/pypi/pyrouge) to run ROUGE evaluation. `pyrouge` provides an easier-to-use interface for the official Perl ROUGE package, which you must install for `pyrouge` to work. Here are some useful instructions on how to do this:
* [How to setup Perl ROUGE](http://kavita-ganesan.com/rouge-howto)
43 changes: 35 additions & 8 deletions beam_search.py
@@ -19,6 +19,7 @@
import tensorflow as tf
import numpy as np
import data
import copy

FLAGS = tf.app.flags.FLAGS

@@ -78,7 +79,7 @@ def avg_log_prob(self):
return self.log_prob / len(self.tokens)


def run_beam_search(sess, model, vocab, batch):
def run_beam_search(sess, model, vocab, batch, previous_best_hyps):
"""Performs beam search decoding on the given example.

Args:
@@ -125,20 +126,42 @@ def run_beam_search(sess, model, vocab, batch):
num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
for i in xrange(num_orig_hyps):
h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
for j in xrange(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
n = len(h.tokens)
for j in xrange(FLAGS.topk): # for each of the top topk hyps:
token = topk_ids[i, j]
score = topk_log_probs[i, j]
if FLAGS.dbs_lambda:
# Diverse beam search: penalize this token if a previous trial's hypothesis used it at the same step
if token in [p.tokens[n] for p in previous_best_hyps if n < len(p.tokens)]:
score -= FLAGS.dbs_lambda
# Extend the ith hypothesis with the jth option
new_hyp = h.extend(token=topk_ids[i, j],
log_prob=topk_log_probs[i, j],
new_hyp = h.extend(token=token,
log_prob=score,
state=new_state,
attn_dist=attn_dist,
p_gen=p_gen,
coverage=new_coverage_i)
all_hyps.append(new_hyp)

temperature = model._hps.temperature
if temperature is None or temperature <= 0.:
all_hyps = sort_hyps(all_hyps)
else:
# Instead of keeping the top hypotheses, sample beam_size*2 of them (without replacement)
# with probability proportional to exp(avg_log_prob / temperature)
n = min(FLAGS.beam_size*2, len(all_hyps))
prb = np.exp(np.array([h.avg_log_prob for h in all_hyps]) / temperature)
res = []
for i in xrange(n):
z = np.sum(prb)
r = np.argmax(np.random.multinomial(1, prb / z, 1))
res.append(all_hyps[r])
prb[r] = 0. # make sure we select each element only once
all_hyps = res

# Filter and collect any hypotheses that have produced the end token.
hyps = [] # will contain hypotheses for the next step
for h in sort_hyps(all_hyps): # in order of most likely h
if h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
for h in all_hyps: # sorted by likelihood, unless temperature sampling is used
if h.latest_token == vocab.word2id(data.UNKNOWN_TOKEN): # skip UNKNOWN tokens
continue
elif h.latest_token == vocab.word2id(data.STOP_DECODING): # if stop token is reached...
# If this hypothesis is sufficiently long, put in results. Otherwise discard.
if steps >= FLAGS.min_dec_steps:
results.append(h)
@@ -147,13 +170,17 @@ def run_beam_search(sess, model, vocab, batch):
if len(hyps) == FLAGS.beam_size or len(results) == FLAGS.beam_size:
# Once we've collected beam_size-many hypotheses for the next step, or beam_size-many complete hypotheses, stop.
break
if len(hyps) == 0:
break
while len(hyps) < FLAGS.beam_size:
hyps.append(copy.copy(hyps[-1])) # pad with copies of the last hypothesis so we always have beam_size of them

steps += 1

# At this point, either we've got beam_size results, or we've reached maximum decoder steps

if len(results)==0: # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
results = hyps
if len(results)==0: # we don't have any complete result
return None

# Sort hypotheses by average log probability
hyps_sorted = sort_hyps(results)
56 changes: 34 additions & 22 deletions decode.py
@@ -95,28 +95,40 @@ def decode(self):
article_withunks = data.show_art_oovs(original_article, self._vocab) # string
abstract_withunks = data.show_abs_oovs(original_abstract, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None)) # string

# Run beam search to get best Hypothesis
best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch)

# Extract the output ids from the hypothesis and convert back to words
output_ids = [int(t) for t in best_hyp.tokens[1:]]
decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

# Remove the [STOP] token from decoded_words, if necessary
try:
fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
decoded_words = decoded_words[:fst_stop_idx]
except ValueError:
decoded_words = decoded_words
decoded_output = ' '.join(decoded_words) # single string

if FLAGS.single_pass:
self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
counter += 1 # this is how many examples we've decoded
else:
print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

tf.logging.info('ARTICLE: %s', article_withunks)
tf.logging.info('REFERENCE SUMMARY: %s', abstract_withunks)

all_best_hyp = []
for trial in range(int(FLAGS.ntrials)):
# Run beam search to get best Hypothesis. run_beam_search returns None when no
# complete hypothesis was produced, so keep retrying until we get one.
best_hyp = None
while best_hyp is None:
best_hyp = beam_search.run_beam_search(self._sess, self._model, self._vocab, batch, all_best_hyp)
all_best_hyp.append(best_hyp)

# Extract the output ids from the hypothesis and convert back to words
output_ids = [int(t) for t in best_hyp.tokens[1:]]
decoded_words = data.outputids2words(output_ids, self._vocab, (batch.art_oovs[0] if FLAGS.pointer_gen else None))

# Remove the [STOP] token from decoded_words, if necessary
try:
fst_stop_idx = decoded_words.index(data.STOP_DECODING) # index of the (first) [STOP] symbol
decoded_words = decoded_words[:fst_stop_idx]
except ValueError:
decoded_words = decoded_words
decoded_output = ' '.join(decoded_words) # single string

if FLAGS.single_pass:
self.write_for_rouge(original_abstract_sents, decoded_words, counter) # write ref summary and decoded summary to file, to eval with pyrouge later
counter += 1 # this is how many examples we've decoded
break
else:
tf.logging.info('GENERATED SUMMARY: %.2f %s', np.mean(best_hyp.log_probs), decoded_output)
if int(FLAGS.ntrials) == 1:
# print_results(article_withunks, abstract_withunks, decoded_output) # log output to screen
self.write_for_attnvis(article_withunks, abstract_withunks, decoded_words, best_hyp.attn_dists, best_hyp.p_gens) # write info to .json file for visualization tool

if not FLAGS.single_pass:
# Check if SECS_UNTIL_NEW_CKPT has elapsed; if so return so we can load a new checkpoint
t1 = time.time()
if t1-t0 > SECS_UNTIL_NEW_CKPT:
44 changes: 34 additions & 10 deletions model.py
@@ -22,7 +22,7 @@
import tensorflow as tf
from attention_decoder import attention_decoder
from tensorflow.contrib.tensorboard.plugins import projector

EPS = 1e-8
FLAGS = tf.app.flags.FLAGS

class SummarizationModel(object):
@@ -240,12 +240,12 @@ def _add_seq2seq(self):
if FLAGS.pointer_gen:
final_dists = self._calc_final_dist(vocab_dists, self.attn_dists)
# Take log of final distribution
log_dists = [tf.log(dist) for dist in final_dists]
log_dists = [tf.log(dist + EPS) for dist in final_dists]
else: # just take log of vocab_dists
log_dists = [tf.log(dist) for dist in vocab_dists]
log_dists = [tf.log(dist + EPS) for dist in vocab_dists]


if hps.mode in ['train', 'eval']:
if hps.mode in ['train', 'eval', 'flip']:
# Calculate the loss
with tf.variable_scope('loss'):
if FLAGS.pointer_gen: # calculate loss from log_dists
@@ -279,8 +279,12 @@ def _add_seq2seq(self):
# We run decode beam search mode one decoder step at a time
assert len(log_dists)==1 # log_dists is a singleton list containing shape (batch_size, extended_vsize)
log_dists = log_dists[0]
self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.batch_size*2) # note batch_size=beam_size in decode mode

self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.topk) # note batch_size=beam_size in decode mode
if hps.mode == "flip" or (hps.mode == "train" and hps.flip):
# for flipping use only decoder output without pointer
if FLAGS.pointer_gen:
log_dists = [tf.log(dist + EPS) for dist in vocab_dists]
self._topk_log_probs, self._topk_ids = tf.nn.top_k(log_dists, hps.topk) # note batch_size=beam_size in decode mode

def _add_train_op(self):
"""Sets self._train_op, the op to run for training."""
@@ -290,15 +294,23 @@ def _add_train_op(self):
gradients = tf.gradients(loss_to_minimize, tvars, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE)

# Clip the gradients
with tf.device("/gpu:0"):
with tf.device("/%s"%self._hps.gpu):
grads, global_norm = tf.clip_by_global_norm(gradients, self._hps.max_grad_norm)

# Add a summary
tf.summary.scalar('global_norm', global_norm)

# Apply adagrad optimizer
optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
with tf.device("/gpu:0"):
if self._hps.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(self._hps.lr, initial_accumulator_value=self._hps.adagrad_init_acc)
elif self._hps.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(self._hps.lr)
elif self._hps.optimizer == 'yellowfin':
from yellowfin import YFOptimizer
optimizer = YFOptimizer(lr_factor=self._hps.lr)
elif self._hps.optimizer == 'sgd':
optimizer = tf.train.MomentumOptimizer(self._hps.lr, 0.9) # SGD with momentum 0.9
else:
raise ValueError('Unknown optimizer: %s' % self._hps.optimizer)
with tf.device("/%s"%self._hps.gpu):
self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step, name='train_step')


Expand All @@ -307,7 +319,7 @@ def build_graph(self):
tf.logging.info('Building graph...')
t0 = time.time()
self._add_placeholders()
with tf.device("/gpu:0"):
with tf.device("/%s"%self._hps.gpu):
self._add_seq2seq()
self.global_step = tf.Variable(0, name='global_step', trainable=False)
if self._hps.mode == 'train':
@@ -341,6 +353,17 @@ def run_eval_step(self, sess, batch):
to_return['coverage_loss'] = self._coverage_loss
return sess.run(to_return, feed_dict)

def run_decode(self, sess, batch):
"""Runs the model on a batch and returns the top-k predicted ids and log probs, together with the summaries, loss and global step."""
feed_dict = self._make_feed_dict(batch)
to_return = {
"ids": self._topk_ids,
"probs": self._topk_log_probs,
'summaries': self._summaries,
'loss': self._loss,
'global_step': self.global_step,
}
return sess.run(to_return, feed_dict)

def run_encoder(self, sess, batch):
"""For beam search decoding. Run the encoder on the batch and return the encoder states and decoder initial state.

@@ -381,6 +404,7 @@ def decode_onestep(self, sess, batch, latest_tokens, enc_states, dec_init_states
p_gens: Generation probabilities for this step. A list length beam_size. List of None if in baseline mode.
new_coverage: Coverage vectors for this step. A list of arrays. List of None if coverage is not turned on.
"""
hps = self._hps

beam_size = len(dec_init_states)
