Skip to content

Commit c327c5c

Browse files
committed
小的修改
1 parent 889a0a8 commit c327c5c

File tree

1 file changed

+0
-92
lines changed

1 file changed

+0
-92
lines changed

四则运算_推理代码.py

-92
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
<<<<<<< HEAD
21
import numpy as np
32
import torch
43
import 两数相加
@@ -88,95 +87,4 @@ def get_data(a, opt, b):
8887

8988
model_test(x1, y1)
9089
model_test(x2, y2)
91-
=======
92-
import numpy as np
93-
import torch
94-
import 两数相加
95-
96-
# 定义预测函数
97-
def predict(x):
    """Greedily decode the answer token ids for one encoded question.

    Relies on the module-level ``model`` created in the ``__main__`` block
    and on the mask helpers, vocabulary and device exposed by ``两数相加``.

    Args:
        x: Question token ids with a leading batch dimension
            (``model_test`` passes ``x.unsqueeze(0)``).

    Returns:
        A ``[1, 10]`` LongTensor: ``<SOS>`` followed by the nine generated
        token ids.
    """
    steps = 9  # tokens generated after <SOS> (reduced from 49 to 9)

    model.eval()
    # Pure inference: disable autograd so no graph is recorded for the
    # encoder/decoder passes (saves memory and time; results are unchanged).
    with torch.no_grad():
        mask_pad_x = 两数相加.mask_pad(x)

        # Fixed initial output: <SOS> followed by <PAD> placeholders.
        target = [两数相加.vocab_y['<SOS>']] + [两数相加.vocab_y['<PAD>']] * steps
        target = torch.LongTensor(target).unsqueeze(0)  # shape [1, 10]
        target = target.to(两数相加.device)

        # Embed and encode the question once; reused at every decoding step.
        x = model.embed_x(x)
        x = model.encoder(x, mask_pad_x)

        # Generate token 1 through token 9, one position per iteration.
        for i in range(steps):
            mask_tril_y = 两数相加.mask_tril(target)  # mask future positions
            y = model.embed_y(target)  # embed y, adding position information
            y = model.decoder(x, y, mask_pad_x, mask_tril_y)
            out = model.fc_out(y)
            # The output at position i predicts the token at position i + 1.
            target[:, i + 1] = out[:, i, :].argmax(dim=1)

    return target
121-
122-
# 测试
123-
def _decode_tokens(tokens):
    """Join tokens into text, stopping at the first <EOS> and dropping <SOS>/<PAD>."""
    kept = []
    for tok in tokens:
        if tok == '<EOS>':
            # Anything generated after the end marker is noise; ignore it.
            break
        if tok not in ('<SOS>', '<PAD>'):
            kept.append(tok)
    return ''.join(kept)


def model_test(x, y):
    """Run inference on one sample and print the prediction next to the truth.

    Args:
        x: 1-D LongTensor of question token ids (as returned by ``get_data``).
        y: 1-D LongTensor of reference answer token ids.

    The predicted answer is cut at the first ``<EOS>`` token. Cutting at
    token level (instead of slicing the joined string to a fixed number of
    characters) keeps answers of any length intact and never leaves
    fragments of special tokens in the compared text.
    """
    x, y = x.to(两数相加.device), y.to(两数相加.device)

    # Inference: decode predicted ids back to tokens.
    predicted_ids = predict(x.unsqueeze(0))[0].tolist()
    answer_s = _decode_tokens(两数相加.vocab_yr[i] for i in predicted_ids)

    # Inference is done; everything below only checks correctness.
    question_s = _decode_tokens(两数相加.vocab_xr[i] for i in x.tolist())
    correct_answer_s = _decode_tokens(两数相加.vocab_yr[i] for i in y.tolist())

    if answer_s == correct_answer_s:
        is_correct = '预测正确'
    else:
        is_correct = '错误,正确答案是:' + correct_answer_s

    print("问题:", question_s,'预测答案:', answer_s, is_correct)
143-
144-
# 两数相加测试
145-
def get_data(a, opt, b):
    """Encode the expression ``a opt b`` and its result as padded id tensors.

    Args:
        a: Left operand as a string of characters, e.g. ``'123'``.
        opt: Operator character, e.g. ``'+'``.
        b: Right operand as a string of characters.

    Returns:
        ``(x, y)``: LongTensors of length 10 (question) and 11 (answer),
        each wrapped in ``<SOS>``/``<EOS>`` and padded or truncated to the
        fixed length.
    """
    # Question characters: operand, operator, operand.
    src = [*a, opt, *b]
    # NOTE(review): eval() is kept for simplicity; callers in this file only
    # pass literals. Do not feed it untrusted input.
    tgt = [*str(eval(a + opt + b))]

    # Add start/end markers.
    src = ['<SOS>', *src, '<EOS>']
    tgt = ['<SOS>', *tgt, '<EOS>']

    # Pad with <PAD>, then cut to the fixed lengths (50 -> 10, 51 -> 11).
    src = (src + ['<PAD>'] * 10)[:10]
    tgt = (tgt + ['<PAD>'] * 11)[:11]

    # Map tokens to vocabulary ids and convert to tensors.
    x = torch.LongTensor([两数相加.vocab_x[tok] for tok in src])
    y = torch.LongTensor([两数相加.vocab_y[tok] for tok in tgt])
    return x, y
166-
167-
168-
if __name__ == '__main__':
    # Build a model from the Transformer class.
    model = 两数相加.Transformer()

    # Load the trained weights. map_location remaps the stored tensors onto
    # the device in use, so a checkpoint saved on GPU also loads on a
    # CPU-only machine.
    state_dict = torch.load("./model_plus_final.pth",
                            map_location=两数相加.device)
    model.load_state_dict(state_dict)
    model = model.to(两数相加.device)

    # Sample problems: addition, multiplication, subtraction.
    x1, y1 = get_data('123', '+', '456')
    x2, y2 = get_data('111', '*', '111')
    x4, y4 = get_data('987', '-', '321')

    model_test(x1, y1)
    model_test(x2, y2)
181-
>>>>>>> a9099543f83bda65ec438209ad6dfa2a823f0334
18290
model_test(x4, y4)

0 commit comments

Comments
 (0)