Skip to content

Commit fd06117

Browse files
committed
推理代码,加载已有模型
1 parent 01b10df commit fd06117

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

model_plus_final.pth

306 KB
Binary file not shown.

四则运算_推理代码.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import torch
3+
import 两数相加
4+
5+
# 定义预测函数
6+
def predict(x):
7+
model.eval()
8+
mask_pad_x = 两数相加.mask_pad(x)
9+
target = [两数相加.vocab_y['<SOS>']] + [两数相加.vocab_y['<PAD>']] * 9 # 初始化输出,这个是固定值 #49->9
10+
target = torch.LongTensor(target).unsqueeze(0) # 增加一个维度,shape变为[1,10]
11+
target = target.to(两数相加.device)
12+
x = model.embed_x(x)
13+
# 编码层计算,维度不变
14+
x = model.encoder(x, mask_pad_x)
15+
# 遍历生成第1个词到第9个词
16+
for i in range(9): #49->9
17+
y = target
18+
mask_tril_y = 两数相加.mask_tril(y) # 上三角遮盖
19+
# y编码,添加位置信息
20+
y = model.embed_y(y)
21+
# 解码层计算,维度不变
22+
y = model.decoder(x, y, mask_pad_x, mask_tril_y)
23+
out = model.fc_out(y)
24+
# 取出当前词的输出
25+
out = out[:,i,:]
26+
out = out.argmax(dim=1).detach()
27+
# 以当前词预测下一个词,填到结果中
28+
target[:,i + 1] = out
29+
return target
30+
31+
# 测试
32+
def model_test(x, y):
33+
x, y = x.to(两数相加.device), y.to(两数相加.device)
34+
#直接求答案
35+
answer = ''.join([两数相加.vocab_yr[i] for i in predict(x.unsqueeze(0))[0].tolist()])
36+
37+
#这里推理已经结束,下面代码主要用于比较正确性
38+
question = ''.join([两数相加.vocab_xr[i] for i in x.tolist()])
39+
correct_answer = ''.join([两数相加.vocab_yr[i] for i in y.tolist()])
40+
41+
#把问题和答案中数字无关的字符都去掉
42+
question_s = question.strip('<SOS>PADE')
43+
answer_s = answer[:13].strip('<SOS>PADE') # 偶尔答案中会在结束符后面还生成一些数字,忽略,增加成功率
44+
correct_answer_s = correct_answer.strip('<SOS>PADE')
45+
46+
if answer_s == correct_answer_s:
47+
is_correct = '预测正确'
48+
else:
49+
is_correct = '错误,正确答案是:' + correct_answer_s
50+
51+
print("问题:", question_s,'预测答案:', answer_s, is_correct)
52+
53+
# 两数相加测试
54+
def get_data(a, opt, b):
55+
#为代码简单,不是最优写法
56+
x = list(a) + [opt] + list(b)
57+
y = list(str(eval(a+opt+b)))
58+
# 加上首尾符号
59+
x = ['<SOS>'] + x + ['<EOS>']
60+
y = ['<SOS>'] + y + ['<EOS>']
61+
62+
# 补PAD,直到固定长度
63+
x = x + ['<PAD>'] * 10 #50->10
64+
y = y + ['<PAD>'] * (11) #51->11
65+
x = x[:10] #50->10
66+
y = y[:11] #51->11
67+
68+
# 编码成数据
69+
x = [两数相加.vocab_x[i] for i in x]
70+
y = [两数相加.vocab_y[i] for i in y]
71+
# 转Tensor
72+
x = torch.LongTensor(x)
73+
y = torch.LongTensor(y)
74+
return x, y
75+
76+
77+
if __name__ == '__main__':
78+
# 用Transformer类定义一个模型model
79+
model = 两数相加.Transformer()
80+
81+
#加载已经训练好的模型
82+
model.load_state_dict(torch.load("./model_plus_final.pth"))
83+
model = model.to(两数相加.device)
84+
x1, y1 = get_data('123', '+', '456')
85+
x2, y2 = get_data('111', '*', '111')
86+
x4, y4 = get_data('987', '-', '321')
87+
88+
model_test(x1, y1)
89+
model_test(x2, y2)
90+
model_test(x4, y4)

0 commit comments

Comments
 (0)