Commit

fix test_atten.py; update readme;
qhduan committed Mar 9, 2018
1 parent 672ab7d commit 5013ec7
Showing 4 changed files with 31 additions and 9 deletions.
README.md: 4 additions & 0 deletions
@@ -127,6 +127,10 @@ Input English Sentence:i'm really a bad boy

[Chatbot example](/chatbot/)


+The `test_atten.py` script tests the model and displays a heatmap of the attention weights


# TensorFlow alert

Test in
chatbot_cut/read_vector.py: 21 additions & 0 deletions
@@ -1,11 +1,32 @@
"""
读取一个文本格式的,保存预训练好的embedding的文件
wiki.zh.vec
它的第一行会被忽略
第二行开始,每行是 词 + 空格 + 词向量维度0 + 空格 + 词向量维度1 + ...
参考fasttext的文本格式
https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md
"""

import pickle
import numpy as np
from tqdm import tqdm


def read_vector(path='wiki.zh.vec', output_path='word_vec.pkl'):
"""
读取文本文件 path 中的数据,并且生成一个 dict 写入到 output_path
格式:
word_vec = {
'word_1': np.array(vec_of_word_1),
'word_2': np.array(vec_of_word_2),
...
}
"""
fp = open(path, 'r')
word_vec = {}
first_skip = False
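For context, a minimal usage sketch (not part of this commit) of how the `word_vec.pkl` produced by `read_vector()` could be loaded and queried; the lookup word below is purely illustrative.

```python
import pickle

import numpy as np

# Load the dict written by read_vector(); keys are words, values are np.array vectors.
with open('word_vec.pkl', 'rb') as fp:
    word_vec = pickle.load(fp)

word = '你好'  # illustrative key; any word that appears in wiki.zh.vec works
if word in word_vec:
    vec = word_vec[word]
    print(word, vec.shape, float(np.linalg.norm(vec)))
else:
    print(word, 'not found in the pretrained vectors')
```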
sequence_to_sequence.py: 5 additions & 7 deletions
@@ -25,7 +25,7 @@
import numpy as np
import tensorflow as tf
from tensorflow import layers
-from tensorflow.python.util import nest
+# from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops
from tensorflow.contrib import seq2seq
from tensorflow.contrib.seq2seq import BahdanauAttention
@@ -37,7 +37,7 @@
from tensorflow.contrib.rnn import MultiRNNCell
from tensorflow.contrib.rnn import DropoutWrapper
from tensorflow.contrib.rnn import ResidualWrapper
-from tensorflow.contrib.rnn import LSTMStateTuple
+# from tensorflow.contrib.rnn import LSTMStateTuple

from word_sequence import WordSequence
from data_utils import _get_embed_device
@@ -414,8 +414,6 @@ def build_encoder(self):
parallel_iterations=self.parallel_iterations,
swap_memory=True
)

-            return encoder_outputs, encoder_state
else:
# The bidirectional RNN case is a bit more involved
encoder_cell_bw = self.build_encoder_cell()
@@ -443,7 +441,7 @@ def build_encoder(self):
encoder_state.append(encoder_bw_state[i])
encoder_state = tuple(encoder_state)

-            return encoder_outputs, encoder_state
+        return encoder_outputs, encoder_state


def build_decoder_cell(self, encoder_outputs, encoder_state):
@@ -499,7 +497,7 @@ def build_decoder_cell(self, encoder_outputs, encoder_state):

# In non-training (prediction) mode, when beam search is not enabled, keep the attention history
alignment_history = (
-    not self.mode == 'train' and not self.use_beamsearch_decode
+    self.mode != 'train' and not self.use_beamsearch_decode
)

def cell_input_fn(inputs, attention):
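For readers unfamiliar with the flag being toggled here: a minimal TF 1.x sketch (an assumption-based illustration, not this project's exact code) of how such an `alignment_history` value is typically passed to `tf.contrib.seq2seq.AttentionWrapper`; all variable names below are illustrative.

```python
import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.contrib.rnn import GRUCell

hidden_units = 64
encoder_outputs = tf.placeholder(tf.float32, [None, None, hidden_units])
encoder_lengths = tf.placeholder(tf.int32, [None])
is_training = False       # illustrative stand-in for self.mode == 'train'
use_beamsearch = False    # illustrative stand-in for self.use_beamsearch_decode

attention_mechanism = seq2seq.BahdanauAttention(
    num_units=hidden_units,
    memory=encoder_outputs,
    memory_sequence_length=encoder_lengths)

decoder_cell = seq2seq.AttentionWrapper(
    cell=GRUCell(hidden_units),
    attention_mechanism=attention_mechanism,
    attention_layer_size=hidden_units,
    # Keep attention history only at prediction time without beam search,
    # mirroring the condition in the diff above.
    alignment_history=(not is_training and not use_beamsearch))
```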
@@ -1091,7 +1089,7 @@ def predict(self, sess,

pred, atten = sess.run([
self.decoder_pred_decode,
-    self.final_state[1].alignment_history.stack()
+    self.final_state.alignment_history.stack()
], input_feed)

return pred, atten
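As a reference for the heatmap mentioned in the README, here is a minimal matplotlib sketch (not the project's actual `test_atten.py`) of how the `atten` array returned by `predict` could be visualized; it assumes the stacked alignment history has shape `(decoder_steps, batch_size, encoder_steps)`.

```python
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(atten, sample=0):
    """Plot attention weights for one sample as a heatmap.

    Assumes atten has shape (decoder_steps, batch_size, encoder_steps).
    """
    weights = np.asarray(atten)[:, sample, :]
    plt.imshow(weights, cmap='viridis', aspect='auto')
    plt.xlabel('encoder step')
    plt.ylabel('decoder step')
    plt.colorbar()
    plt.show()

# Example (hypothetical names): pred, atten = model.predict(sess, x, xl)
# plot_attention(atten, sample=0)
```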
test_atten.py: 1 addition & 2 deletions
@@ -22,7 +22,7 @@ def test(bidirectional, cell_type, depth, attention_type):

# Training part

-split = 9900
+split = int(len(x_data) * 0.9)
x_train, x_test, y_train, y_test = (
x_data[:split], x_data[split:], y_data[:split], y_data[split:])
n_epoch = 2
@@ -86,7 +86,6 @@ def test(bidirectional, cell_type, depth, attention_type):
cell_type=cell_type,
depth=depth,
attention_type=attention_type,
-alignment_history=True,
parallel_iterations=1
)
init = tf.global_variables_initializer()
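A hedged usage sketch of the `test()` helper shown above; the specific argument values (`'lstm'`, `'gru'`, `'Bahdanau'`, a depth of 1) are assumptions about what this repository accepts, not values taken from the diff.

```python
# Assumes test_atten.py can be imported without side effects.
from test_atten import test

for bidirectional in (True, False):
    for cell_type in ('lstm', 'gru'):          # assumed valid cell types
        test(bidirectional=bidirectional,
             cell_type=cell_type,
             depth=1,
             attention_type='Bahdanau')        # assumed valid attention type
```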
