169 changes: 169 additions & 0 deletions .gitignore
@@ -0,0 +1,169 @@

# Created by https://www.gitignore.io/api/python,pycharm

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff:
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/dictionaries

# Sensitive or high-churn files:
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.xml
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml

# Gradle:
.idea/**/gradle.xml
.idea/**/libraries

# CMake
cmake-build-debug/

# Mongo Explorer plugin:
.idea/**/mongoSettings.xml

## File-based project format:
*.iws

## Plugin-specific files:

# IntelliJ
/out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
.idea/sonarlint

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# End of https://www.gitignore.io/api/python,pycharm

# Custom
.idea/
1 change: 1 addition & 0 deletions attention_decoder.py
@@ -16,6 +16,7 @@

"""This file defines the decoder"""

+from builtins import str
import tensorflow as tf
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import array_ops
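Note: `from builtins import str` comes from the `future` compatibility package: it is a no-op on Python 3 and rebinds `str` to a backport of the Python 3 text type on Python 2. A minimal sketch of what it guarantees, assuming `future` is installed (`pip install future`):

```python
from builtins import str  # no-op on Python 3; backported text type on Python 2

s = str("some decoder output")
assert isinstance(s, str)  # s is a unicode text string on both interpreters

# Without the import, Python 2's native str is a byte string, so code
# that mixes it with unicode can hit UnicodeDecodeError at runtime.
```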
31 changes: 18 additions & 13 deletions batcher.py
@@ -15,8 +15,13 @@
# ==============================================================================

"""This file contains code to process data into batches"""
+from __future__ import absolute_import

-import Queue
+from future import standard_library
+standard_library.install_aliases()
+from builtins import range
+from builtins import object
+import queue
from random import shuffle
from threading import Thread
import time
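Note: Python 3 renamed the `Queue` module to `queue`; `standard_library.install_aliases()` installs the Python 3 names on Python 2, so the plain `import queue` above works under both interpreters. For readers not using `future`, a rough sketch of the equivalent manual shim:

```python
# Manual fallback (not what this PR uses) for the stdlib rename.
try:
    import queue  # Python 3 name
except ImportError:
    import Queue as queue  # Python 2 name

q = queue.Queue(maxsize=2)  # identical API under either import
q.put("example")
print(q.get())  # -> example
```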
@@ -199,7 +204,7 @@ def init_decoder_seq(self, example_list, hps):
for i, ex in enumerate(example_list):
self.dec_batch[i, :] = ex.dec_input[:]
self.target_batch[i, :] = ex.target[:]
-for j in xrange(ex.dec_len):
+for j in range(ex.dec_len):
self.padding_mask[i][j] = 1

def store_orig_strings(self, example_list):
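Note: `xrange` no longer exists in Python 3. The `from builtins import range` added at the top of this file gives Python 2 a lazy, Python 3-style `range`, so loops like the one above keep their constant-memory behavior on both interpreters. A minimal sketch:

```python
# With the `future` package's builtins.range, Python 2 also gets a lazy,
# iterator-style range, so large loops stay constant-memory.
from builtins import range

total = 0
for j in range(10**6):  # no million-element list is built up front
    total += j
print(total)  # 499999500000
```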
@@ -229,8 +234,8 @@ def __init__(self, data_path, vocab, hps, single_pass):
self._single_pass = single_pass

# Initialize a queue of Batches waiting to be used, and a queue of Examples waiting to be batched
-self._batch_queue = Queue.Queue(self.BATCH_QUEUE_MAX)
-self._example_queue = Queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size)
+self._batch_queue = queue.Queue(self.BATCH_QUEUE_MAX)
+self._example_queue = queue.Queue(self.BATCH_QUEUE_MAX * self._hps.batch_size)

# Different settings depending on whether we're in single_pass mode or not
if single_pass:
@@ -245,12 +250,12 @@ def __init__(self, data_path, vocab, hps, single_pass):

# Start the threads that load the queues
self._example_q_threads = []
-for _ in xrange(self._num_example_q_threads):
+for _ in range(self._num_example_q_threads):
self._example_q_threads.append(Thread(target=self.fill_example_queue))
self._example_q_threads[-1].daemon = True
self._example_q_threads[-1].start()
self._batch_q_threads = []
-for _ in xrange(self._num_batch_q_threads):
+for _ in range(self._num_batch_q_threads):
self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
self._batch_q_threads[-1].daemon = True
self._batch_q_threads[-1].start()
@@ -287,7 +292,7 @@ def fill_example_queue(self):

while True:
try:
-(article, abstract) = input_gen.next() # read the next example from file. article and abstract are both strings.
+(article, abstract) = next(input_gen) # read the next example from file. article and abstract are both strings.
except StopIteration: # if there are no more examples:
tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
if self._single_pass:
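Note: bound `.next()` methods are gone in Python 3; the `next()` builtin (available since Python 2.6) is the portable way to advance a generator, as this hunk does. A small sketch:

```python
# Generators have no .next() method on Python 3; the next() builtin
# works on both interpreters.
def pairs():
    yield ("article one", "abstract one")
    yield ("article two", "abstract two")

input_gen = pairs()
article, abstract = next(input_gen)   # portable
# article, abstract = input_gen.next()  # AttributeError on Python 3
```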
@@ -311,13 +316,13 @@ def fill_batch_queue(self):
if self._hps.mode != 'decode':
# Get bucketing_cache_size-many batches of Examples into a list, then sort
inputs = []
-for _ in xrange(self._hps.batch_size * self._bucketing_cache_size):
+for _ in range(self._hps.batch_size * self._bucketing_cache_size):
inputs.append(self._example_queue.get())
inputs = sorted(inputs, key=lambda inp: inp.enc_len) # sort by length of encoder sequence

# Group the sorted Examples into batches, optionally shuffle the batches, and place in the batch queue.
batches = []
-for i in xrange(0, len(inputs), self._hps.batch_size):
+for i in range(0, len(inputs), self._hps.batch_size):
batches.append(inputs[i:i + self._hps.batch_size])
if not self._single_pass:
shuffle(batches)
@@ -326,7 +331,7 @@

else: # beam search decode mode
ex = self._example_queue.get()
-b = [ex for _ in xrange(self._hps.batch_size)]
+b = [ex for _ in range(self._hps.batch_size)]
self._batch_queue.put(Batch(b, self._hps, self._vocab))


@@ -356,10 +361,10 @@ def text_generator(self, example_generator):
Args:
example_generator: a generator of tf.Examples from file. See data.example_generator"""
while True:
-e = example_generator.next() # e is a tf.Example
+e = next(example_generator) # e is a tf.Example
try:
-article_text = e.features.feature['article'].bytes_list.value[0] # the article text was saved under the key 'article' in the data files
-abstract_text = e.features.feature['abstract'].bytes_list.value[0] # the abstract text was saved under the key 'abstract' in the data files
+article_text = e.features.feature['article'].bytes_list.value[0].decode("utf-8") # the article text was saved under the key 'article' in the data files
+abstract_text = e.features.feature['abstract'].bytes_list.value[0].decode("utf-8") # the abstract text was saved under the key 'abstract' in the data files
except ValueError:
tf.logging.error('Failed to get article or abstract from example: %s', text_format.MessageToString(e))
continue
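Note: under Python 3, the `bytes_list` values of a `tf.Example` come back as `bytes` rather than `str`, which is why the two `.decode("utf-8")` calls above are needed. A short round-trip sketch, reusing the `example_pb2` import that `data.py` already relies on:

```python
# tf.Example stores text features as a bytes_list, so on Python 3 the
# values read back as bytes and need an explicit decode.
from tensorflow.core.example import example_pb2

ex = example_pb2.Example()
ex.features.feature['article'].bytes_list.value.append(
    "some article text".encode("utf-8"))

raw = ex.features.feature['article'].bytes_list.value[0]  # bytes on py3
text = raw.decode("utf-8")                                # back to str
assert text == "some article text"
```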
15 changes: 10 additions & 5 deletions beam_search.py
@@ -15,7 +15,12 @@
# ==============================================================================

"""This file contains code to run beam search decoding"""
+from __future__ import division
+from __future__ import absolute_import

+from builtins import range
+from builtins import object
+from past.utils import old_div
import tensorflow as tf
import numpy as np
import data
@@ -75,7 +80,7 @@ def log_prob(self):
@property
def avg_log_prob(self):
# normalize log probability by number of tokens (otherwise longer sequences always have lower probability)
-return self.log_prob / len(self.tokens)
+return old_div(self.log_prob, len(self.tokens))


def run_beam_search(sess, model, vocab, batch):
@@ -102,13 +107,13 @@ def run_beam_search(sess, model, vocab, batch):
attn_dists=[],
p_gens=[],
coverage=np.zeros([batch.enc_batch.shape[1]]) # zero vector of length attention_length
-) for _ in xrange(FLAGS.beam_size)]
+) for _ in range(FLAGS.beam_size)]
results = [] # this will contain finished hypotheses (those that have emitted the [STOP] token)

steps = 0
while steps < FLAGS.max_dec_steps and len(results) < FLAGS.beam_size:
latest_tokens = [h.latest_token for h in hyps] # latest token produced by each hypothesis
-latest_tokens = [t if t in xrange(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
+latest_tokens = [t if t in range(vocab.size()) else vocab.word2id(data.UNKNOWN_TOKEN) for t in latest_tokens] # change any in-article temporary OOV ids to [UNK] id, so that we can lookup word embeddings
states = [h.state for h in hyps] # list of current decoder states of the hypotheses
prev_coverage = [h.coverage for h in hyps] # list of coverage vectors (or None)

@@ -123,9 +128,9 @@
# Extend each hypothesis and collect them all in all_hyps
all_hyps = []
num_orig_hyps = 1 if steps == 0 else len(hyps) # On the first step, we only had one original hypothesis (the initial hypothesis). On subsequent steps, all original hypotheses are distinct.
-for i in xrange(num_orig_hyps):
+for i in range(num_orig_hyps):
h, new_state, attn_dist, p_gen, new_coverage_i = hyps[i], new_states[i], attn_dists[i], p_gens[i], new_coverage[i] # take the ith hypothesis and new decoder state info
-for j in xrange(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
+for j in range(FLAGS.beam_size * 2): # for each of the top 2*beam_size hyps:
# Extend the ith hypothesis with the jth option
new_hyp = h.extend(token=topk_ids[i, j],
log_prob=topk_log_probs[i, j],
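Note: `from __future__ import division` makes `/` true division even on Python 2, so futurize conservatively wraps existing divisions in `past.utils.old_div` to preserve the original semantics. In `avg_log_prob` the numerator is a float, so `old_div` behaves exactly like plain `/` here. A sketch of the distinction:

```python
# old_div preserves Python 2 division semantics at a call site after
# `from __future__ import division` makes `/` true division.
from past.utils import old_div  # shipped with the `future` package

print(7 / 2)            # 3.5 (true division on both interpreters)
print(old_div(7, 2))    # 3   (old integer-floor behavior)
print(old_div(7.0, 2))  # 3.5 (floats unchanged, as in avg_log_prob)
```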
14 changes: 8 additions & 6 deletions data.py
@@ -15,7 +15,9 @@
# ==============================================================================

"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
+from __future__ import print_function

+from builtins import range
import glob
import random
import struct
@@ -58,7 +60,7 @@ def __init__(self, vocab_file, max_size):
for line in vocab_f:
pieces = line.split()
if len(pieces) != 2:
-print 'Warning: incorrectly formatted line in vocabulary file: %s\n' % line
+print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
continue
w = pieces[0]
if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
@@ -69,10 +71,10 @@ def __init__(self, vocab_file, max_size):
self._id_to_word[self._count] = w
self._count += 1
if max_size != 0 and self._count >= max_size:
print "max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count)
print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
break

print "Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1])
print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))

def word2id(self, word):
"""Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
@@ -97,11 +99,11 @@ def write_metadata(self, fpath):
Args:
fpath: place to write the metadata file
"""
print "Writing word embedding metadata file to %s..." % (fpath)
print("Writing word embedding metadata file to %s..." % (fpath))
with open(fpath, "w") as f:
fieldnames = ['word']
writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
-for i in xrange(self.size()):
+for i in range(self.size()):
writer.writerow({"word": self._id_to_word[i]})


@@ -137,7 +139,7 @@ def example_generator(data_path, single_pass):
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
yield example_pb2.Example.FromString(example_str)
if single_pass:
print "example_generator completed reading all datafiles. No more data."
print("example_generator completed reading all datafiles. No more data.")
break


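Note: beyond the mechanical `print()` conversions, `example_generator` is the reader side of a simple length-prefixed binary format: an 8-byte record length, then that many bytes of a serialized `tf.Example`, unpacked above with `struct.unpack('%ds' % str_len, ...)` and `example_pb2.Example.FromString(...)`. A sketch of the matching writer (`write_record` is a hypothetical helper, not part of this repo):

```python
import struct

def write_record(out_file, example_proto):
    # Hypothetical helper: writes one tf.Example in the length-prefixed
    # format that example_generator reads back.
    serialized = example_proto.SerializeToString()
    out_file.write(struct.pack('q', len(serialized)))
    out_file.write(struct.pack('%ds' % len(serialized), serialized))
```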