-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathevaluate.py
96 lines (74 loc) · 3.29 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import time
from collections import Counter
import pickle
from models.HMM import HMM
from models.CRF import CRFModel
from evaluating import Metrics
from operate_bilstm import BiLSTM_operator
from utils import save_model,flatten_lists
def hmm_train_eval(train_data,test_data,word2id,tag2id,remove_0=False):
    """Train an HMM tagger, save it, and report scores on the test set.

    Args:
        train_data: tuple (word_lists, tag_lists) of training sequences.
        test_data: tuple (word_lists, tag_lists) of test sequences.
        word2id: mapping from word to integer id.
        tag2id: mapping from tag to integer id.
        remove_0: forwarded to Metrics as remove_0; presumably excludes
            the 'O' tag from the reported scores — confirm in Metrics.

    Returns:
        The predicted tag lists for the test set.
    """
    print("hmm模型的评估与训练...")
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    hmm_model = HMM(len(tag2id), len(word2id))
    hmm_model.train(train_word_lists, train_tag_lists, word2id, tag2id)
    save_model(hmm_model, "./ckpts/hmm.pkl")

    # Evaluate on the held-out test set.
    pred_tag_lists = hmm_model.test(test_word_lists, word2id, tag2id)
    # BUG FIX: remove_0 was accepted but never forwarded to Metrics,
    # making the flag a silent no-op (bilstm_train_and_eval forwards it).
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_0=remove_0)
    metrics.report_scores(dtype='HMM')
    return pred_tag_lists
def crf_train_eval(train_data,test_data,remove_0=False):
    """Train a CRF tagger, save it, and report scores on the test set.

    Args:
        train_data: tuple (word_lists, tag_lists) of training sequences.
        test_data: tuple (word_lists, tag_lists) of test sequences.
        remove_0: forwarded to Metrics as remove_0; presumably excludes
            the 'O' tag from the reported scores — confirm in Metrics.

    Returns:
        The predicted tag lists for the test set.
    """
    print("crf模型的评估与训练...")
    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    crf_model = CRFModel()
    crf_model.train(train_word_lists, train_tag_lists)
    save_model(crf_model, "./ckpts/crf.pkl")

    pred_tag_lists = crf_model.test(test_word_lists)
    # BUG FIX: remove_0 was accepted but never forwarded to Metrics,
    # making the flag a silent no-op (bilstm_train_and_eval forwards it).
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_0=remove_0)
    metrics.report_scores(dtype='CRF')
    return pred_tag_lists
def bilstm_train_and_eval(train_data,dev_data,test_data,word2id,tag2id,crf=True,remove_0=False):
    """Train a BiLSTM tagger (optionally with a CRF layer) and evaluate it.

    Args:
        train_data: tuple (word_lists, tag_lists) of training sequences.
        dev_data: tuple (word_lists, tag_lists) used for validation.
        test_data: tuple (word_lists, tag_lists) of test sequences.
        word2id: mapping from word to integer id.
        tag2id: mapping from tag to integer id.
        crf: when True, train the BiLSTM+CRF variant.
        remove_0: forwarded to Metrics as remove_0.

    Returns:
        The predicted tag lists for the test set.
    """
    if crf:
        print("bilstm+crf模型的评估与训练...")
    else:
        print("bilstm模型的评估与训练...")

    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    t0 = time.time()
    runner = BiLSTM_operator(len(word2id), len(tag2id), crf=crf)
    model_name = "bilstm_crf" if crf else "bilstm"
    print("start to train the {} ...".format(model_name))
    runner.train(train_word_lists, train_tag_lists,
                 dev_word_lists, dev_tag_lists, word2id, tag2id)
    save_model(runner, "./ckpts/" + model_name + ".pkl")
    print("训练完毕,共用时{}秒.".format(int(time.time() - t0)))

    print("评估{}模型中...".format(model_name))
    # The operator's test() may re-order/truncate the gold tags, so it
    # returns them alongside the predictions — use the returned pair.
    pred_tag_lists, test_tag_lists = runner.test(
        test_word_lists, test_tag_lists, word2id, tag2id)
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_0=remove_0)
    metrics.report_scores(dtype='Bi_LSTM+CRF' if crf else 'Bi_LSTM')
    return pred_tag_lists
def ensemble_evaluate(results, targets, remove_O=False):
    """Majority-vote ensemble over the predictions of several models.

    Args:
        results: list of per-model predicted tag lists; each entry is
            flattened with flatten_lists before voting.
        targets: gold tag lists, flattened the same way.
        remove_O: forwarded to Metrics as remove_0; presumably excludes
            the 'O' tag from the reported scores — confirm in Metrics.
    """
    # BUG FIX: the original flattened results[i] in place, mutating the
    # caller's list; build flattened copies instead.
    flat_results = [flatten_lists(result) for result in results]

    # Majority vote per token position; Counter.most_common(1) picks the
    # most frequent tag (ties resolved by first-seen order).
    pred_tags = [Counter(position).most_common(1)[0][0]
                 for position in zip(*flat_results)]

    flat_targets = flatten_lists(targets)
    assert len(pred_tags) == len(flat_targets)

    # Report the actual ensemble size instead of the hard-coded "四个" (four).
    print("Ensemble {}个模型的结果如下:".format(len(flat_results)))
    metrics = Metrics(flat_targets, pred_tags, remove_0=remove_O)
    # BUG FIX: corrected the misspelled report label 'ensembel'.
    metrics.report_scores(dtype='ensemble')
    # metrics.report_confusion_matrix()