pylint; fix a lot of things; test some of them;
qhduan committed Feb 18, 2018
1 parent f2a1850 commit 65bb89e
Showing 16 changed files with 210 additions and 547 deletions.
7 changes: 0 additions & 7 deletions chatbot_ad/test.py
@@ -85,17 +85,11 @@ def test(bidirectional, cell_type, depth,
sess.run(init)
model_pred.load(sess, save_path)


last = None

while True:
user_text = input('Input Chat Sentence:')
if user_text in ('exit', 'quit'):
exit(0)
x_test = list(user_text.lower())
# if last is not None and last:
# print(last)
# x_test = last + [WordSequence.PAD_TAG] + x_test
x_test = [x_test]
bar = batch_flow([x_test], [ws], 1)
x, xl = next(bar)
@@ -120,7 +114,6 @@ def test(bidirectional, cell_type, depth,
if pp == WordSequence.PAD_TAG:
break
p.append(pp)
last = p


def main():
15 changes: 10 additions & 5 deletions chatbot_ad/train_ad.py
@@ -9,7 +9,7 @@
import numpy as np
import tensorflow as tf
from tqdm import tqdm
import jieba
# import jieba

sys.path.append('..')

@@ -113,7 +113,6 @@ def test(bidirectional, cell_type, depth,

for epoch in range(1, n_epoch + 1):
costs = []
lengths = []
bar = tqdm(range(steps), total=steps,
desc='epoch {}, loss=0.000000'.format(epoch))
for _ in bar:
@@ -144,7 +143,13 @@ def test(bidirectional, cell_type, depth,

costs.append(cost)
# lengths.append(np.mean(al))
bar.set_description('epoch {} loss={:.6f} rmean={:.4f} rmin={:.4f} rmax={:.4f} rmed={:.4f}'.format(
des = ('epoch {} '
       'loss={:.6f} '
       'rmean={:.4f} '
       'rmin={:.4f} '
       'rmax={:.4f} '
       'rmed={:.4f}')
bar.set_description(des.format(
epoch,
np.mean(costs),
np.mean(rewards),
@@ -158,11 +163,11 @@ def test(bidirectional, cell_type, depth,

def repeat_reward(arr):
"""重复越多,分数越低"""
arr = list(arr)
from collections import Counter
arr = list(arr)
counter = Counter(arr)
t = sum([i for i in counter.values() if i > 1])
return(max(0, 1 - t / len(counter)))
return max(0, 1 - t / len(counter))
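
For reference, repeat_reward above gives a full score of 1.0 to a sequence with no repeated tokens and pushes the score toward 0 as repetition grows. A standalone copy with a few worked checks (the sample inputs are illustrative only):

def repeat_reward(arr):
    """The more repetition, the lower the score (same logic as the function above)."""
    from collections import Counter
    arr = list(arr)
    counter = Counter(arr)
    t = sum([i for i in counter.values() if i > 1])
    return max(0, 1 - t / len(counter))

print(repeat_reward(['你', '好', '吗']))        # no repeats: 1.0
print(repeat_reward(['a', 'a', 'b', 'c']))       # one token seen twice: 1 - 2/3 ≈ 0.33
print(repeat_reward(['哈', '哈', '哈', '哈']))   # one token repeated 4 times: max(0, 1 - 4/1) = 0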


def chinese_reward(text):
3 changes: 0 additions & 3 deletions chatbot_ad/train_ad_no_preload.py
@@ -2,13 +2,10 @@
Basic parameter-combination tests for the SequenceToSequence model
"""

import sys
import random
import pickle

import numpy as np
import tensorflow as tf
from tqdm import tqdm
from train_ad import test


25 changes: 19 additions & 6 deletions chatbot_ad/train_discriminative.py
@@ -8,9 +8,9 @@

import numpy as np
import tensorflow as tf
import jieba
# import jieba
from tqdm import tqdm
from sklearn.utils import shuffle
# from sklearn.utils import shuffle

sys.path.append('..')

@@ -109,8 +109,9 @@ def test(bidirectional, cell_type, depth,
al = []
new_a = []
for aa in a:
for j in range(0, len(aa)):
if aa[j] == WordSequence.END:
j = 0
for j, aaj in enumerate(aa):
if aaj == WordSequence.END:
break
new_a.append(list(aa[:j]))
if j <= 0:
@@ -119,9 +120,21 @@

max_len = max((a.shape[1], y.shape[1]))
if a.shape[1] < max_len:
a = np.concatenate((a, np.ones((batch_size, max_len - a.shape[1])) * WordSequence.END), axis=1)
a = np.concatenate(
(
a,
np.ones(
(batch_size, max_len - a.shape[1])
) * WordSequence.END
), axis=1)
if y.shape[1] < max_len:
y = np.concatenate((y, np.ones((batch_size, max_len - y.shape[1])) * WordSequence.END), axis=1)
y = np.concatenate(
(
y,
np.ones(
(batch_size, max_len - y.shape[1])
) * WordSequence.END
), axis=1)

targets = np.array(([0] * len(a)) + ([1] * len(a)))

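
The two np.concatenate calls above right-pad whichever batch (a or y) is shorter along axis 1 with WordSequence.END, so that both batches end up with the same width before being combined with the 0/1 targets. A minimal standalone sketch of that padding step; the END id of 3 is only a placeholder, not necessarily the value used by the repository's WordSequence:

import numpy as np

END = 3  # placeholder for WordSequence.END

def pad_to(batch, max_len, end=END):
    """Right-pad a 2-D batch with the END id until its second dimension equals max_len."""
    if batch.shape[1] >= max_len:
        return batch
    pad = np.ones((batch.shape[0], max_len - batch.shape[1])) * end
    return np.concatenate((batch, pad), axis=1)

a = np.array([[5, 6, 3], [7, 8, 3]])              # e.g. generated batch, width 3
y = np.array([[5, 6, 9, 2, 3], [7, 7, 7, 7, 3]])  # e.g. reference batch, width 5
max_len = max((a.shape[1], y.shape[1]))
a = pad_to(a, max_len)
y = pad_to(y, max_len)
print(a.shape, y.shape)  # both (2, 5)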
6 changes: 6 additions & 0 deletions chatbot_ad/train_tfidf.py
@@ -1,3 +1,6 @@
"""
Train a tf-idf model
"""

import sys
import pickle
@@ -7,6 +10,9 @@
sys.path.append('..')

def main():
"""
Train a tf-idf model
"""
x_data, _, _ = pickle.load(
open('chatbot.pkl', 'rb'))

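
The rest of main() is not shown in this diff. Purely as a hypothetical reference (the repository may train its tf-idf differently), a minimal fit over the already tokenized sentences could use scikit-learn's TfidfVectorizer with a pass-through analyzer:

import pickle

from sklearn.feature_extraction.text import TfidfVectorizer


def identity_analyzer(tokens):
    """Pass-through analyzer: each sentence is already a list of tokens."""
    return tokens


x_data, _, _ = pickle.load(open('chatbot.pkl', 'rb'))  # sentences stored as token lists

vectorizer = TfidfVectorizer(analyzer=identity_analyzer)
matrix = vectorizer.fit_transform(x_data)
print(matrix.shape)  # (number of sentences, vocabulary size)

with open('tfidf.pkl', 'wb') as fp:  # hypothetical output path
    pickle.dump(vectorizer, fp)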
53 changes: 45 additions & 8 deletions data_utils.py
@@ -4,7 +4,6 @@

import random
import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

VOCAB_SIZE_THRESHOLD_CPU = 50000
@@ -26,19 +25,57 @@ def _get_embed_device(vocab_size):
return "/gpu:0"


def transform_sentence(q, ws_q, q_max):
x = ws_q.transform(q, max_len=q_max)
xl = len(q)
return x, xl
def transform_sentence(sentence, ws, max_len=None):
"""转换一个单独句子
Args:
sentence: 一句话,例如一个数组['你', '好', '吗']
ws: 一个WordSequence对象,转换器
max_len:
进行padding的长度,也就是如果sentence长度小于max_len
则padding到max_len这么长
Ret:
encoded:
一个经过ws转换的数组,例如[4, 5, 6, 3]
encoded_len: 上面的长度
"""
encoded = ws.transform(
sentence,
max_len=max_len if max_len is not None else len(sentence))
encoded_len = len(encoded)
return encoded, encoded_len
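
A quick check of the contract described in the docstring above. StubWS is a hypothetical stand-in that only mimics the transform(sentence, max_len=...) call this function makes; the repository's real WordSequence class is not part of this diff:

from data_utils import transform_sentence

class StubWS:
    """Hypothetical stand-in for WordSequence: fixed vocabulary, pads with id 3."""
    word2id = {'你': 4, '好': 5, '吗': 6}

    def transform(self, sentence, max_len=None):
        ids = [self.word2id[w] for w in sentence]
        if max_len is not None and len(ids) < max_len:
            ids += [3] * (max_len - len(ids))  # 3 stands in for a PAD/END id
        return ids

encoded, encoded_len = transform_sentence(['你', '好', '吗'], StubWS(), max_len=5)
print(encoded)      # [4, 5, 6, 3, 3]
print(encoded_len)  # 5, the length of the padded encoding, not of the raw sentence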


def batch_flow(data, ws, batch_size, raw=False):
"""从数据中随机 batch_size 个的数据,然后 yield 出去
Args:
data:
是一个数组,必须包含一个护着更多个同等的数据队列数组
ws:
可以是一个WordSequence对象,也可以是多个组成的数组
如果是多个,那么数组数量应该与data的数据数量保持一致,即len(data) == len(ws)
batch_size:
批量的大小
raw:
是否返回原始对象,如果为True,假设结果ret,那么len(ret) == len(data) * 3
如果为False,那么len(ret) == len(data) * 2
例如需要输入问题与答案的队列,问题队列Q = (q_1, q_2, q_3 ... q_n)
答案队列A = (a_1, a_2, a_3 ... a_n),有len(Q) == len(A)
ws是一个Q与A共用的WordSequence对象,
那么可以有: batch_flow([Q, A], ws, batch_size=32)
这样会返回一个generator,每次next(generator)会返回一个包含4个对象的数组,分别代表:
next(generator) == q_i_encoded, q_i_len, a_i_encoded, a_i_len
如果设置raw = True,则:
next(generator) == q_i_encoded, q_i_len, q_i, a_i_encoded, a_i_len, a_i
其中 q_i_encoded 相当于 ws.transform(q_i)
不过经过了batch修正,把一个batch中每个结果的长度,padding到了数组内最大的句子长度
"""

all_data = list(zip(*data))

if isinstance(ws, list) or isinstance(ws, tuple):
if isinstance(ws, (list, tuple)):
assert len(ws) == len(data), \
'len(ws) must equal to len(data) if ws is list or tuple'

@@ -55,9 +92,9 @@ def batch_flow(data, ws, batch_size, raw=False):
max_len = max([len(x[j]) for x in data_batch])
max_lens.append(max_len)

for i, d in enumerate(data_batch):
for d in data_batch:
for j in range(len(data)):
if isinstance(ws, list) or isinstance(ws, tuple):
if isinstance(ws, (list, tuple)):
w = ws[j]
else:
w = ws
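
To make the batch_flow contract concrete, here is a hypothetical end-to-end sketch. StubWordSequence mimics the transform(sentence, max_len=...) interface; if the real batch_flow relies on additional WordSequence attributes not visible in this diff, the stub would need them as well. The toy Q/A queues are illustrative only:

from data_utils import batch_flow

class StubWordSequence:
    """Hypothetical stand-in for WordSequence with a transform(sentence, max_len=...) method."""
    def __init__(self, sentences):
        vocab = sorted({token for s in sentences for token in s})
        self.word2id = {w: i + 4 for i, w in enumerate(vocab)}  # low ids left for special tags

    def transform(self, sentence, max_len=None):
        ids = [self.word2id[token] for token in sentence]
        if max_len is not None and len(ids) < max_len:
            ids += [3] * (max_len - len(ids))  # 3 stands in for a PAD/END id
        return ids

# Toy question/answer queues of equal length, already split into characters.
Q = [list('你好吗'), list('再见')]
A = [list('我很好'), list('再见')]
ws = StubWordSequence(Q + A)  # shared by questions and answers

flow = batch_flow([Q, A], ws, batch_size=2)
x, xl, y, yl = next(flow)  # encoded questions, question lengths, encoded answers, answer lengths
print(x, xl)
print(y, yl)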
