
Commit

change width batch_flow change
qhduan committed Feb 17, 2018
1 parent 5db496a commit aec98a3
Showing 23 changed files with 156 additions and 1,165 deletions.
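The core of this commit (as read from the diffs below, not stated in the commit itself) is an API change to the `batch_flow` helper in `data_utils`: call sites stop passing separate input/target arrays with two `WordSequence` vocabularies and instead pass a list of data arrays plus one shared `WordSequence`, and `chatbot.pkl` now stores a single vocabulary. A minimal sketch of the new call pattern, inferred from the call sites in `train.py` and `test.py` (the `data_utils.py` implementation is not part of the visible diff, so the exact signature and yielded tuples are assumptions):

```python
# Sketch only: batch_flow's behaviour is inferred from the call sites in this
# commit; data_utils.py itself is not shown in the diff.
import pickle

from data_utils import batch_flow          # repo helper, as imported in train.py/test.py
from word_sequence import WordSequence     # pylint: disable=unused-import  (needed to unpickle ws)

# chatbot.pkl now stores a single shared vocabulary instead of two.
x_data, y_data, ws = pickle.load(open('chatbot.pkl', 'rb'))

# Training: pass both arrays; each batch is assumed to yield (x, x_len, y, y_len).
flow = batch_flow([x_data, y_data], ws, 256)
x, xl, y, yl = next(flow)

# Inference: pass a single array; each batch is assumed to yield (x, x_len).
x_test = [list('你好')]
bar = batch_flow([x_test], ws, 1)
x, xl = next(bar)
```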
2 changes: 2 additions & 0 deletions chatbot/.gitignore
@@ -2,3 +2,5 @@
*.zip
*.pkl
*.conv
*.ckpt*
checkpoint
4 changes: 2 additions & 2 deletions chatbot/README.md
@@ -29,8 +29,8 @@ python3 extract_conv.py

## 4. Training

Run `train.py` to train (checkpoint defaults to `/tmp/s2ss_chatbot.ckpt`)
Run `python3 train.py` to train (checkpoint defaults to `./s2ss_chatbot.ckpt`)

## 5. Testing (test the translation)

Run `test.py` to see the test results
Run `python3 test.py` to see the test results
24 changes: 19 additions & 5 deletions chatbot/extract_conv.py
@@ -8,14 +8,22 @@

sys.path.append('..')


def make_split(line):
"""构造合并两个句子之间的符号
"""
if re.match(r'.*([,。…?!~\.,!?])$', ''.join(line)):
return []
return [',']

def main(limit=15):

def good_line(line):
if len(re.findall(r'[a-zA-Z]', ''.join(line))) > 5:
return False
return True


def main(limit=15, min_limit=5):
"""执行程序
Args:
limit: 只输出句子长度小于limit的句子
@@ -49,12 +57,18 @@ def main(limit=15):
last_line = None
if i > 0:
last_line = group[i - 1]
if not good_line(last_line):
last_line = None
next_line = None
if i < len(group) - 1:
next_line = group[i + 1]
if not good_line(next_line):
next_line = None
next_next_line = None
if i < len(group) - 2:
next_next_line = group[i + 2]
if not good_line(next_next_line):
next_next_line = None

if next_line:
x_data.append(line)
@@ -74,20 +88,20 @@
print('-' * 20)

data = list(zip(x_data, y_data))
data = [(x, y) for x, y in data if len(x) < limit and len(y) < limit]
data = [
(x, y) for x, y in data
if len(x) < limit and len(y) < limit and len(y) >= min_limit]
x_data, y_data = zip(*data)

print('fit word_sequence')

ws_input = WordSequence()
ws_target = WordSequence()
ws_input.fit(x_data)
ws_target.fit(y_data)

print('dump')

pickle.dump(
(x_data, y_data, ws_input, ws_target),
(x_data, y_data, ws_input),
open('chatbot.pkl', 'wb')
)

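For reference, here is a minimal, self-contained illustration of the two filters that `extract_conv.py` gains in this commit: `good_line`, which rejects lines containing more than five ASCII letters, and the `min_limit` length check on target sentences. In the actual script `good_line` is applied to neighbouring lines while conversation pairs are being built, so this simplified sketch only demonstrates the predicates themselves on hypothetical data:

```python
# Simplified demonstration of the new filters; in extract_conv.py good_line is
# applied to neighbouring lines while pairs are built, not to finished pairs.
import re

def good_line(line):
    """Reject lines that contain more than 5 ASCII letters."""
    return len(re.findall(r'[a-zA-Z]', ''.join(line))) <= 5

pairs = [
    (list('今天天气不错'), list('是的,很好')),          # kept
    (list('你好'), list('hello world, how are you')),   # dropped: too many ASCII letters
]
limit, min_limit = 15, 5
kept = [
    (x, y) for x, y in pairs
    if good_line(y) and len(x) < limit and min_limit <= len(y) < limit
]
print(len(kept))  # 1
```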
111 changes: 8 additions & 103 deletions chatbot/test.py
@@ -20,7 +20,7 @@ def test(bidirectional, cell_type, depth,
from data_utils import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, _, ws_input, ws_target = pickle.load(open('chatbot.pkl', 'rb'))
x_data, _, ws = pickle.load(open('chatbot.pkl', 'rb'))

for x in x_data[:5]:
print(' '.join(x))
@@ -32,13 +32,13 @@ def test(bidirectional, cell_type, depth,
)

# save_path = '/tmp/s2ss_chatbot.ckpt'
save_path = '/tmp/s2ss_chatbot.ckpt'
save_path = './s2ss_chatbot.ckpt'

# Test section
tf.reset_default_graph()
model_pred = SequenceToSequence(
input_vocab_size=len(ws_input),
target_vocab_size=len(ws_target),
input_vocab_size=len(ws),
target_vocab_size=len(ws),
batch_size=1,
mode='decode',
beam_width=0,
@@ -63,8 +63,8 @@ def test(bidirectional, cell_type, depth,
if user_text in ('exit', 'quit'):
exit(0)
x_test = [list(user_text.lower())]
bar = batch_flow(x_test, x_test, ws_input, ws_target, 1)
x, xl, _, _ = next(bar)
bar = batch_flow([x_test], ws, 1)
x, xl = next(bar)
# x = np.array([
# list(reversed(xx))
# for xx in x
@@ -76,103 +76,8 @@ def test(bidirectional, cell_type, depth,
np.array(xl)
)
print(pred)
print(ws_input.inverse_transform(x[0]))
print(ws_target.inverse_transform(pred[0]))


# x_data, y_data, ws_input, ws_target = pickle.load(open('data.pkl', 'rb'))
#
# # get some fake data
# # x_data, y_data, ws_input, ws_target = generate(size=10000)
#
# # training section
#
# split = int(len(x_data) * 0.8)
# _, x_test, _, y_test = (
# x_data[:split], x_data[split:], y_data[:split], y_data[split:])
#
# config = tf.ConfigProto(
# device_count={'CPU': 1, 'GPU': 0},
# allow_soft_placement=True,
# log_device_placement=False
# )
#
# save_path = '/tmp/s2ss_chatbot/'
#
# # test section
# tf.reset_default_graph()
# model_pred = SequenceToSequence(
# input_vocab_size=len(ws_input),
# target_vocab_size=len(ws_target),
# batch_size=1,
# mode='decode',
# beam_width=5,
# bidirectional=bidirectional,
# cell_type=cell_type,
# depth=depth,
# attention_type=attention_type,
# use_residual=use_residual,
# use_dropout=use_dropout,
# parallel_iterations=1,
# hidden_units=128 # for test
# )
# init = tf.global_variables_initializer()
#
# with tf.Session(config=config) as sess:
# sess.run(init)
# model_pred.load(sess, save_path)
#
# bar = batch_flow(x_test, y_test, ws_input, ws_target, 1)
# t = 0
# for x, xl, y, _ in bar:
# pred = model_pred.predict(
# sess,
# np.array(x),
# np.array(xl)
# )
# print(ws_input.inverse_transform(x[0]))
# print(ws_target.inverse_transform(y[0]))
# print(ws_target.inverse_transform(pred[0, :, 0]))
# t += 1
# if t >= 3:
# break
#
# tf.reset_default_graph()
# model_pred = SequenceToSequence(
# input_vocab_size=len(ws_input),
# target_vocab_size=len(ws_target),
# batch_size=1,
# mode='decode',
# beam_width=1,
# bidirectional=bidirectional,
# cell_type=cell_type,
# depth=depth,
# attention_type=attention_type,
# use_residual=use_residual,
# use_dropout=use_dropout,
# parallel_iterations=1,
# hidden_units=128 # for test
# )
# init = tf.global_variables_initializer()
#
# with tf.Session(config=config) as sess:
# sess.run(init)
# model_pred.load(sess, save_path)
#
# bar = batch_flow(x_test, y_test, ws_input, ws_target, 1)
# t = 0
# for x, xl, y, _ in bar:
# pred = model_pred.predict(
# sess,
# np.array(x),
# np.array(xl)
# )
# print(ws_input.inverse_transform(x[0]))
# print(ws_target.inverse_transform(y[0]))
# print(ws_target.inverse_transform(pred[0]))
# t += 1
# if t >= 3:
# break
print(ws.inverse_transform(x[0]))
print(ws.inverse_transform(pred[0]))


def main():
44 changes: 18 additions & 26 deletions chatbot/train.py
@@ -18,19 +18,13 @@ def test(bidirectional, cell_type, depth,
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from data_utils import batch_flow_bucket
from data_utils import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, y_data, ws_input, ws_target = pickle.load(
x_data, y_data, ws = pickle.load(
open('chatbot.pkl', 'rb'))

# get some fake data
# x_data, y_data, ws_input, ws_target = generate(size=10000)

# training section
split = int(len(x_data) * 0.8)
x_train, x_test, y_train, y_test = (
x_data[:split], x_data[split:], y_data[:split], y_data[split:])
n_epoch = 5
batch_size = 256
steps = int(len(x_train) / batch_size) + 1
@@ -41,7 +35,7 @@ def test(bidirectional, cell_type, depth,
log_device_placement=False
)

save_path = '/tmp/s2ss_chatbot.ckpt'
save_path = './s2ss_chatbot.ckpt'

tf.reset_default_graph()
with tf.Graph().as_default():
@@ -52,8 +46,8 @@ def test(bidirectional, cell_type, depth,
with tf.Session(config=config) as sess:

model = SequenceToSequence(
input_vocab_size=len(ws_input),
target_vocab_size=len(ws_target),
input_vocab_size=len(ws),
target_vocab_size=len(ws),
batch_size=batch_size,
learning_rate=0.001,
bidirectional=bidirectional,
@@ -73,9 +67,7 @@ def test(bidirectional, cell_type, depth,
# print(sess.run(model.input_layer.kernel))
# exit(1)

flow = batch_flow_bucket(
x_train, y_train, ws_input, ws_target, batch_size
)
flow = batch_flow([x_data, y_data], ws, batch_size)

for epoch in range(1, n_epoch + 1):
costs = []
@@ -98,8 +90,8 @@ def test(bidirectional, cell_type, depth,
# Test section
tf.reset_default_graph()
model_pred = SequenceToSequence(
input_vocab_size=len(ws_input),
target_vocab_size=len(ws_target),
input_vocab_size=len(ws),
target_vocab_size=len(ws),
batch_size=1,
mode='decode',
beam_width=12,
@@ -119,25 +111,25 @@ def test(bidirectional, cell_type, depth,
sess.run(init)
model_pred.load(sess, save_path)

bar = batch_flow_bucket(x_test, y_test, ws_input, ws_target, 1)
bar = batch_flow([x_data, y_data], ws, 1)
t = 0
for x, xl, y, yl in bar:
pred = model_pred.predict(
sess,
np.array(x),
np.array(xl)
)
print(ws_input.inverse_transform(x[0]))
print(ws_target.inverse_transform(y[0]))
print(ws_target.inverse_transform(pred[0]))
print(ws.inverse_transform(x[0]))
print(ws.inverse_transform(y[0]))
print(ws.inverse_transform(pred[0]))
t += 1
if t >= 3:
break

tf.reset_default_graph()
model_pred = SequenceToSequence(
input_vocab_size=len(ws_input),
target_vocab_size=len(ws_target),
input_vocab_size=len(ws),
target_vocab_size=len(ws),
batch_size=1,
mode='decode',
beam_width=1,
@@ -157,17 +149,17 @@ def test(bidirectional, cell_type, depth,
sess.run(init)
model_pred.load(sess, save_path)

bar = batch_flow_bucket(x_test, y_test, ws_input, ws_target, 1)
bar = batch_flow([x_data, y_data], ws, 1)
t = 0
for x, xl, y, yl in bar:
pred = model_pred.predict(
sess,
np.array(x),
np.array(xl)
)
print(ws_input.inverse_transform(x[0]))
print(ws_target.inverse_transform(y[0]))
print(ws_target.inverse_transform(pred[0]))
print(ws.inverse_transform(x[0]))
print(ws.inverse_transform(y[0]))
print(ws.inverse_transform(pred[0]))
t += 1
if t >= 3:
break
6 changes: 0 additions & 6 deletions chatbot_rl/.gitignore

This file was deleted.
