1
+ import numpy as np
2
+ import torch
3
+ import 两数相加
4
+
5
+ # 定义预测函数
6
+ def predict (x ):
7
+ model .eval ()
8
+ mask_pad_x = 两数相加 .mask_pad (x )
9
+ target = [两数相加 .vocab_y ['<SOS>' ]] + [两数相加 .vocab_y ['<PAD>' ]] * 9 # 初始化输出,这个是固定值 #49->9
10
+ target = torch .LongTensor (target ).unsqueeze (0 ) # 增加一个维度,shape变为[1,10]
11
+ target = target .to (两数相加 .device )
12
+ x = model .embed_x (x )
13
+ # 编码层计算,维度不变
14
+ x = model .encoder (x , mask_pad_x )
15
+ # 遍历生成第1个词到第9个词
16
+ for i in range (9 ): #49->9
17
+ y = target
18
+ mask_tril_y = 两数相加 .mask_tril (y ) # 上三角遮盖
19
+ # y编码,添加位置信息
20
+ y = model .embed_y (y )
21
+ # 解码层计算,维度不变
22
+ y = model .decoder (x , y , mask_pad_x , mask_tril_y )
23
+ out = model .fc_out (y )
24
+ # 取出当前词的输出
25
+ out = out [:,i ,:]
26
+ out = out .argmax (dim = 1 ).detach ()
27
+ # 以当前词预测下一个词,填到结果中
28
+ target [:,i + 1 ] = out
29
+ return target
30
+
31
+ # 测试
32
+ def model_test (x , y ):
33
+ x , y = x .to (两数相加 .device ), y .to (两数相加 .device )
34
+ #直接求答案
35
+ answer = '' .join ([两数相加 .vocab_yr [i ] for i in predict (x .unsqueeze (0 ))[0 ].tolist ()])
36
+
37
+ #这里推理已经结束,下面代码主要用于比较正确性
38
+ question = '' .join ([两数相加 .vocab_xr [i ] for i in x .tolist ()])
39
+ correct_answer = '' .join ([两数相加 .vocab_yr [i ] for i in y .tolist ()])
40
+
41
+ #把问题和答案中数字无关的字符都去掉
42
+ question_s = question .strip ('<SOS>PADE' )
43
+ answer_s = answer [:13 ].strip ('<SOS>PADE' ) # 偶尔答案中会在结束符后面还生成一些数字,忽略,增加成功率
44
+ correct_answer_s = correct_answer .strip ('<SOS>PADE' )
45
+
46
+ if answer_s == correct_answer_s :
47
+ is_correct = '预测正确'
48
+ else :
49
+ is_correct = '错误,正确答案是:' + correct_answer_s
50
+
51
+ print ("问题:" , question_s ,'预测答案:' , answer_s , is_correct )
52
+
53
+ # 两数相加测试
54
+ def get_data (a , opt , b ):
55
+ #为代码简单,不是最优写法
56
+ x = list (a ) + [opt ] + list (b )
57
+ y = list (str (eval (a + opt + b )))
58
+ # 加上首尾符号
59
+ x = ['<SOS>' ] + x + ['<EOS>' ]
60
+ y = ['<SOS>' ] + y + ['<EOS>' ]
61
+
62
+ # 补PAD,直到固定长度
63
+ x = x + ['<PAD>' ] * 10 #50->10
64
+ y = y + ['<PAD>' ] * (11 ) #51->11
65
+ x = x [:10 ] #50->10
66
+ y = y [:11 ] #51->11
67
+
68
+ # 编码成数据
69
+ x = [两数相加 .vocab_x [i ] for i in x ]
70
+ y = [两数相加 .vocab_y [i ] for i in y ]
71
+ # 转Tensor
72
+ x = torch .LongTensor (x )
73
+ y = torch .LongTensor (y )
74
+ return x , y
75
+
76
+
77
+ if __name__ == '__main__' :
78
+ # 用Transformer类定义一个模型model
79
+ model = 两数相加 .Transformer ()
80
+
81
+ #加载已经训练好的模型
82
+ model .load_state_dict (torch .load ("./model_plus_final.pth" ))
83
+ model = model .to (两数相加 .device )
84
+ x1 , y1 = get_data ('123' , '+' , '456' )
85
+ x2 , y2 = get_data ('111' , '*' , '111' )
86
+ x4 , y4 = get_data ('987' , '-' , '321' )
87
+
88
+ model_test (x1 , y1 )
89
+ model_test (x2 , y2 )
90
+ model_test (x4 , y4 )
0 commit comments