1
- < << << << HEAD
2
1
import numpy as np
3
2
import torch
4
3
import 两数相加
@@ -88,95 +87,4 @@ def get_data(a, opt, b):
88
87
89
88
model_test (x1 , y1 )
90
89
model_test (x2 , y2 )
91
- == == == =
92
- import numpy as np
93
- import torch
94
- import 两数相加
95
-
96
- # 定义预测函数
97
def predict(x):
    """Greedily decode the output sequence for one encoded input.

    Uses the module-level ``model`` together with ``mask_pad``, ``mask_tril``,
    ``vocab_y`` and ``device`` from the ``两数相加`` module.

    Args:
        x: Input token-id tensor. ``model_test`` passes it with a leading
            batch dimension added via ``unsqueeze(0)``.

    Returns:
        LongTensor of shape [1, 10]: the ``<SOS>`` id followed by the 9
        generated token ids.
    """
    num_steps = 9  # tokens generated after <SOS>; target length is num_steps + 1

    model.eval()
    mask_pad_x = 两数相加.mask_pad(x)

    # Fixed initial decoder input: <SOS> followed by <PAD> placeholders.
    target = [两数相加.vocab_y['<SOS>']] + [两数相加.vocab_y['<PAD>']] * num_steps
    target = torch.LongTensor(target).unsqueeze(0)  # add batch dim -> [1, 10]
    target = target.to(两数相加.device)

    # Inference only: no_grad avoids building autograd graphs on every step.
    with torch.no_grad():
        x = model.embed_x(x)
        # Encoder pass (original note: dimensions unchanged).
        x = model.encoder(x, mask_pad_x)

        # Generate tokens 1..num_steps one position at a time.
        for i in range(num_steps):
            # Mask for the decoder input (original note: upper-triangular mask).
            mask_tril_y = 两数相加.mask_tril(target)
            # Embed the current target (original note: adds position info).
            y = model.embed_y(target)
            # Decoder pass (original note: dimensions unchanged).
            y = model.decoder(x, y, mask_pad_x, mask_tril_y)
            out = model.fc_out(y)
            # Take the output at position i and pick the highest-scoring id.
            out = out[:, i, :].argmax(dim=1)
            # The prediction at position i becomes the token at position i + 1.
            target[:, i + 1] = out
    return target
121
-
122
- # 测试
123
def _tokens_to_text(tokens):
    """Join tokens into text, ignoring everything from the first '<EOS>' on.

    Leading and trailing '<SOS>' / '<PAD>' tokens are dropped as well.
    """
    if '<EOS>' in tokens:
        tokens = tokens[:tokens.index('<EOS>')]
    specials = ('<SOS>', '<PAD>')
    start, end = 0, len(tokens)
    while start < end and tokens[start] in specials:
        start += 1
    while end > start and tokens[end - 1] in specials:
        end -= 1
    return ''.join(tokens[start:end])


def model_test(x, y):
    """Run ``predict`` on one sample and print whether the answer matches.

    Args:
        x: Encoded question tensor (no batch dimension; one is added here).
        y: Encoded reference-answer tensor.

    Prints the decoded question, the predicted answer and a verdict.
    """
    x, y = x.to(两数相加.device), y.to(两数相加.device)

    # Inference itself.
    predicted_ids = predict(x.unsqueeze(0))[0].tolist()

    # Everything below only decodes ids back to text to check correctness.
    # The model occasionally emits extra tokens after <EOS>; cutting at the
    # first <EOS> token ignores them regardless of the answer's length.
    answer_s = _tokens_to_text([两数相加.vocab_yr[i] for i in predicted_ids])
    question_s = _tokens_to_text([两数相加.vocab_xr[i] for i in x.tolist()])
    correct_answer_s = _tokens_to_text([两数相加.vocab_yr[i] for i in y.tolist()])

    if answer_s == correct_answer_s:
        is_correct = '预测正确'
    else:
        is_correct = '错误,正确答案是:' + correct_answer_s

    print("问题:", question_s, '预测答案:', answer_s, is_correct)
143
-
144
- # 两数相加测试
145
def get_data(a, opt, b):
    """Build the encoded (question, answer) tensor pair for ``a opt b``.

    Args:
        a: Left operand as a string; each character becomes one token.
        opt: Operator string, used as a single token.
        b: Right operand as a string; each character becomes one token.

    Returns:
        Tuple ``(x, y)`` of LongTensors: ``x`` holds exactly 10 token ids and
        ``y`` exactly 11, each wrapped in <SOS>/<EOS> and padded with <PAD>
        (or truncated) to that length.
    """
    x_len, y_len = 10, 11

    question = [*a, opt, *b]
    # NOTE(review): eval() executes the concatenated string; only call this
    # with trusted literal operands and operators.
    result = [*str(eval(a + opt + b))]

    # Wrap in start/end markers.
    question = ['<SOS>', *question, '<EOS>']
    result = ['<SOS>', *result, '<EOS>']

    # Over-pad with <PAD>, then cut to the fixed length.
    question = (question + ['<PAD>'] * x_len)[:x_len]
    result = (result + ['<PAD>'] * y_len)[:y_len]

    # Map tokens to vocabulary ids and convert to tensors.
    vocab_x, vocab_y = 两数相加.vocab_x, 两数相加.vocab_y
    x = torch.LongTensor([vocab_x[tok] for tok in question])
    y = torch.LongTensor([vocab_y[tok] for tok in result])
    return x, y
166
-
167
-
168
if __name__ == '__main__':
    # Build a model instance from the Transformer class.
    model = 两数相加.Transformer()

    # Load the trained weights. map_location remaps the stored tensors onto
    # the current device, so a checkpoint saved on another device (e.g. a
    # GPU) still loads on this machine.
    state_dict = torch.load("./model_plus_final.pth",
                            map_location=两数相加.device)
    model.load_state_dict(state_dict)
    model = model.to(两数相加.device)

    x1, y1 = get_data('123', '+', '456')
    x2, y2 = get_data('111', '*', '111')
    x4, y4 = get_data('987', '-', '321')

    model_test(x1, y1)
    model_test(x2, y2)
181
- > >> >> >> a9099543f83bda65ec438209ad6dfa2a823f0334
182
90
model_test (x4 , y4 )
0 commit comments