Skip to content

Commit

Permalink
修改chatbot训练部分
Browse files: browse the repository at this point in the history
  • Loading branch information
qhduan committed Mar 23, 2018
1 parent 3c15f4b commit 1210ee3
Show file tree
Hide file tree
Showing 14 changed files with 72 additions and 1,279 deletions.
17 changes: 9 additions & 8 deletions chatbot/extract_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ def main(limit=20, x_limit=3, y_limit=6):
if next_line:
x_data.append(line)
y_data.append(next_line)
# if last_line and next_line:
# x_data.append(last_line + make_split(last_line) + line)
# y_data.append(next_line)
# if next_line and next_next_line:
# x_data.append(line)
# y_data.append(next_line + make_split(next_line) \
# + next_next_line)
if last_line and next_line:
x_data.append(last_line + make_split(last_line) + line)
y_data.append(next_line)
if next_line and next_next_line:
x_data.append(line)
y_data.append(next_line + make_split(next_line) \
+ next_next_line)

print(len(x_data), len(y_data))
for ask, answer in zip(x_data[:20], y_data[:20]):
Expand All @@ -119,9 +119,10 @@ def main(limit=20, x_limit=3, y_limit=6):
print('dump')

pickle.dump(
(x_data, y_data, ws_input),
(x_data, y_data),
open('chatbot.pkl', 'wb')
)
pickle.dump(ws_input, open('ws.pkl', 'wb'))

print('done')

Expand Down
64 changes: 0 additions & 64 deletions chatbot/gen_same_person.py

This file was deleted.

13 changes: 13 additions & 0 deletions chatbot/params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"bidirectional": true,
"use_residual": false,
"use_dropout": false,
"time_major": false,
"cell_type": "lstm",
"depth": 2,
"attention_type": "Bahdanau",
"hidden_units": 1024,
"optimizer": "adam",
"learning_rate": 0.001,
"embedding_size": 300
}
34 changes: 7 additions & 27 deletions chatbot/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
sys.path.append('..')


def test(bidirectional, cell_type, depth,
attention_type, use_residual, use_dropout, time_major, hidden_units):
def test(params):
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from data_utils import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, _, ws = pickle.load(open('chatbot.pkl', 'rb'))
x_data, _ = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))

for x in x_data[:5]:
print(' '.join(x))
Expand All @@ -44,16 +44,7 @@ def test(bidirectional, cell_type, depth,
batch_size=1,
mode='decode',
beam_width=0,
bidirectional=bidirectional,
cell_type=cell_type,
depth=depth,
attention_type=attention_type,
use_residual=use_residual,
use_dropout=use_dropout,
parallel_iterations=1,
time_major=time_major,
hidden_units=hidden_units,
share_embedding=True
**params
)
init = tf.global_variables_initializer()

Expand Down Expand Up @@ -91,20 +82,9 @@ def test(bidirectional, cell_type, depth,


def main():
"""入口程序,开始测试不同参数组合"""
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
test(
bidirectional=True,
cell_type='lstm',
depth=2,
attention_type='Bahdanau',
use_residual=False,
use_dropout=False,
time_major=False,
hidden_units=512
)
"""入口程序"""
import json
test(json.load(open('params.json')))


if __name__ == '__main__':
Expand Down
34 changes: 7 additions & 27 deletions chatbot/test_anti.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
sys.path.append('..')


def test(bidirectional, cell_type, depth,
attention_type, use_residual, use_dropout, time_major, hidden_units):
def test(params):
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from data_utils import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, _, ws = pickle.load(open('chatbot.pkl', 'rb'))
x_data, _ = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))

for x in x_data[:5]:
print(' '.join(x))
Expand All @@ -44,16 +44,7 @@ def test(bidirectional, cell_type, depth,
batch_size=1,
mode='decode',
beam_width=0,
bidirectional=bidirectional,
cell_type=cell_type,
depth=depth,
attention_type=attention_type,
use_residual=use_residual,
use_dropout=use_dropout,
parallel_iterations=1,
time_major=time_major,
hidden_units=hidden_units,
share_embedding=True
**params
)
init = tf.global_variables_initializer()

Expand Down Expand Up @@ -91,20 +82,9 @@ def test(bidirectional, cell_type, depth,


def main():
"""入口程序,开始测试不同参数组合"""
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
test(
bidirectional=True,
cell_type='lstm',
depth=2,
attention_type='Bahdanau',
use_residual=False,
use_dropout=False,
time_major=False,
hidden_units=512
)
"""入口程序"""
import json
test(json.load(open('params.json')))


if __name__ == '__main__':
Expand Down
44 changes: 7 additions & 37 deletions chatbot/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
sys.path.append('..')


def test(bidirectional, cell_type, depth,
attention_type, use_residual, use_dropout, time_major, hidden_units):
def test(params):
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from data_utils import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

_, _, ws = pickle.load(open('chatbot.pkl', 'rb'))
ws = pickle.load(open('ws.pkl', 'rb'))

# for x in x_data[:5]:
# print(' '.join(x))
Expand All @@ -45,16 +44,7 @@ def test(bidirectional, cell_type, depth,
batch_size=1,
mode='decode',
beam_width=12,
bidirectional=bidirectional,
cell_type=cell_type,
depth=depth,
attention_type=attention_type,
use_residual=use_residual,
use_dropout=use_dropout,
parallel_iterations=1,
time_major=time_major,
hidden_units=hidden_units,
share_embedding=True
**params
)
init = tf.global_variables_initializer()
sess_rl = tf.Session(config=config)
Expand All @@ -69,16 +59,7 @@ def test(bidirectional, cell_type, depth,
batch_size=1,
mode='decode',
beam_width=12,
bidirectional=bidirectional,
cell_type=cell_type,
depth=depth,
attention_type=attention_type,
use_residual=use_residual,
use_dropout=use_dropout,
parallel_iterations=1,
time_major=time_major,
hidden_units=hidden_units,
share_embedding=True
**params
)
init = tf.global_variables_initializer()
sess = tf.Session(config=config)
Expand Down Expand Up @@ -118,20 +99,9 @@ def test(bidirectional, cell_type, depth,


def main():
"""入口程序,开始测试不同参数组合"""
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
test(
bidirectional=True,
cell_type='lstm',
depth=2,
attention_type='Bahdanau',
use_residual=False,
use_dropout=False,
time_major=False,
hidden_units=512
)
"""入口程序"""
import json
test(json.load(open('params.json')))


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 1210ee3

Please sign in to comment.