Skip to content

Commit

Permalink
en2zh works
Browse files Browse the repository at this point in the history
  • Loading branch information
qhduan committed Feb 9, 2018
1 parent aa3eb86 commit ff75f26
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 88 deletions.
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,5 @@
*.py[cod]
.ipynb_checkpoints
__pycache__

*.tmx
*.gz
bak
*.pkl
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
- [ ] 后续这个repo会作为一个基础完成一个dialog system
- seq2seq模型至少可以作为通用NER实现(截止2018年初,最好的NER应该还是bi-LSTM + CRF)

# Known issues

residsual没应用到decoder上,那个部分可能还有问题

# TensorFlow alert

test only tensorflow == 1.4.1
Expand Down
4 changes: 4 additions & 0 deletions en2zh/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

*.tmx
*.gz
*.pkl
22 changes: 18 additions & 4 deletions en2zh/README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@

# 英汉翻译测试

## 1、下载数据

下载页面

http://opus.nlpl.eu/OpenSubtitles2016.php
http://opus.nlpl.eu/OpenSubtitles2018.php

下载链接:

下载链接(不知道能不能直接用):
wget -O "en-zh_cn.tmx.gz" "http://opus.nlpl.eu/download.php?f=OpenSubtitles2018/en-zh_cn.tmx.gz"

http://opus.nlpl.eu/download.php?f=OpenSubtitles2016/en-zh_zh.tmx.gz
## 2、解压数据

这个数据是`英文-中文`的平行语聊

下载并解压数据,然后重命名为 `en-zh_zh.tmx`
解压缩:

gunzip -k en-zh_cn.tmx.gz

下载并解压数据,然后重命名为 `en-zh_zh.tmx` (如果有有必要)

这应该是一个xml格式(在`linux`下可以用`head`命令查看下是否正确)

## 3、预处理数据

运行 `extract_tmx.py` 得到 `data.pkl`

## 4、训练数据

运行 `train.py` 训练(默认到`/tmp/s2ss_en2zh`目录)

## 5、测试数据(测试翻译)

运行 `test.py` 查看测试结果
33 changes: 19 additions & 14 deletions en2zh/extract_tmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
把tmx(xml)的数据解开,分词,然后保存到data.pkl
"""

import re
# import re
import sys
import pickle
import xml.etree.ElementTree as ET
Expand All @@ -20,7 +20,7 @@ def main():
from word_sequence import WordSequence

x_data, y_data = [], []
tree = ET.parse('en-zh_zh.tmx')
tree = ET.parse('en-zh_cn.tmx')
root = tree.getroot()
body = root.find('body')
for tu in tqdm(body.findall('tu')):
Expand All @@ -29,7 +29,7 @@ def main():
for tuv in tu.findall('tuv'):
if list(tuv.attrib.values())[0] == 'en':
en += tuv.find('seg').text
elif list(tuv.attrib.values())[0] == 'zh_zh':
elif list(tuv.attrib.values())[0] == 'zh_cn':
zh += tuv.find('seg').text

if en and zh:
Expand All @@ -43,28 +43,33 @@ def main():

print('tokenize')

def en_tokenize(text):
# text = re.sub('[\((][^\))]+[\))]', '', text)
return nltk.word_tokenize(text.lower())

x_data = [
nltk.word_tokenize(x.lower())
en_tokenize(x)
for x in tqdm(x_data)
]

def zh_tokenize(text):
text = text.replace(',', ',')
text = text.replace('。', '.')
text = text.replace('?', '?')
text = text.replace('!', '!')
text = re.sub(r'[^\u4e00-\u9fff,\.\?\!…《》]', '', text)
text = text.strip()
text = jieba.lcut(text)
# text = text.replace(',', ',')
# text = text.replace('。', '.')
# text = text.replace('?', '?')
# text = text.replace('!', '!')
# text = text.replace(':', ':')
# text = re.sub(r'[^\u4e00-\u9fff,\.\?\!…《》]:', '', text)
# text = text.strip()
text = jieba.lcut(text.lower())
return text

y_data = [
zh_tokenize(y.lower())
zh_tokenize(y)
for y in tqdm(y_data)
]

data = list(zip(x_data, y_data))
data = [(x, y) for x, y in data if len(x) < 10 and len(y) < 10]
data = [(x, y) for x, y in data if len(x) < 15 and len(y) < 15]

x_data, y_data = [x[0] for x in data], [x[1] for x in data]

Expand All @@ -84,7 +89,7 @@ def zh_tokenize(text):

pickle.dump(
(x_data, y_data, ws_input, ws_target),
open('data.pkl', 'wb')
open('en-zh_cn.pkl', 'wb')
)

print('done')
Expand Down
13 changes: 9 additions & 4 deletions en2zh/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@


def test(bidirectional, cell_type, depth,
attention_type, use_residual, use_dropout):
attention_type, use_residual, use_dropout, time_major, hidden_units):
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from sequence_to_sequence import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, _, ws_input, ws_target = pickle.load(open('data.pkl', 'rb'))
x_data, _, ws_input, ws_target = pickle.load(open('en-zh_cn.pkl', 'rb'))

for x in x_data[:5]:
print(' '.join(x))
Expand Down Expand Up @@ -49,7 +49,8 @@ def test(bidirectional, cell_type, depth,
use_residual=use_residual,
use_dropout=use_dropout,
parallel_iterations=1,
hidden_units=512 # for test
time_major=time_major,
hidden_units=hidden_units # for test
)
init = tf.global_variables_initializer()

Expand All @@ -64,6 +65,10 @@ def test(bidirectional, cell_type, depth,
x_test = [nltk.word_tokenize(user_text.lower())]
bar = batch_flow(x_test, x_test, ws_input, ws_target, 1)
x, xl, _, _ = next(bar)
# x = np.array([
# list(reversed(xx))
# for xx in x
# ])
print(x, xl)
pred = model_pred.predict(
sess,
Expand Down Expand Up @@ -175,7 +180,7 @@ def main():
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
test(True, 'gru', 1, 'Bahdanau', False, True)
test(True, 'lstm', 2, 'Bahdanau', False, True, True, 64)


if __name__ == '__main__':
Expand Down
30 changes: 19 additions & 11 deletions en2zh/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@


def test(bidirectional, cell_type, depth,
attention_type, use_residual, use_dropout, time_major):
attention_type, use_residual, use_dropout, time_major, hidden_units):
"""测试不同参数在生成的假数据上的运行结果"""

from sequence_to_sequence import SequenceToSequence
from sequence_to_sequence import batch_flow
from word_sequence import WordSequence # pylint: disable=unused-variable

x_data, y_data, ws_input, ws_target = pickle.load(open('data.pkl', 'rb'))
x_data, y_data, ws_input, ws_target = pickle.load(
open('en-zh_cn.pkl', 'rb'))

# 获取一些假数据
# x_data, y_data, ws_input, ws_target = generate(size=10000)
Expand All @@ -30,9 +31,8 @@ def test(bidirectional, cell_type, depth,
split = int(len(x_data) * 0.8)
x_train, x_test, y_train, y_test = (
x_data[:split], x_data[split:], y_data[:split], y_data[split:])
n_epoch = 10
batch_size = 32
hidden_units = 128
n_epoch = 2
batch_size = 256
steps = int(len(x_train) / batch_size) + 1

config = tf.ConfigProto(
Expand Down Expand Up @@ -64,7 +64,7 @@ def test(bidirectional, cell_type, depth,
use_dropout=use_dropout,
parallel_iterations=64,
hidden_units=hidden_units,
optimizer='momentum',
optimizer='adam',
time_major=time_major
)
init = tf.global_variables_initializer()
Expand All @@ -83,10 +83,10 @@ def test(bidirectional, cell_type, depth,
for _ in bar:
x, xl, y, yl = next(flow)
# trick, reverse input
y = np.array([
list(reversed(yy))
for yy in y
])
# x = np.array([
# list(reversed(xx))
# for xx in x
# ])
cost = model.train(sess, x, xl, y, yl)
costs.append(cost)
bar.set_description('epoch {} loss={:.6f}'.format(
Expand Down Expand Up @@ -123,6 +123,10 @@ def test(bidirectional, cell_type, depth,
bar = batch_flow(x_test, y_test, ws_input, ws_target, 1)
t = 0
for x, xl, y, yl in bar:
# x = np.array([
# list(reversed(xx))
# for xx in x
# ])
pred = model_pred.predict(
sess,
np.array(x),
Expand Down Expand Up @@ -161,6 +165,10 @@ def test(bidirectional, cell_type, depth,
bar = batch_flow(x_test, y_test, ws_input, ws_target, 1)
t = 0
for x, xl, y, yl in bar:
# x = np.array([
# list(reversed(xx))
# for xx in x
# ])
pred = model_pred.predict(
sess,
np.array(x),
Expand All @@ -179,7 +187,7 @@ def main():
random.seed(0)
np.random.seed(0)
tf.set_random_seed(0)
test(True, 'gru', 2, 'Bahdanau', False, True, True)
test(True, 'lstm', 2, 'Bahdanau', False, True, True, 64)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit ff75f26

Please sign in to comment.