-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetrics.py
141 lines (112 loc) · 4.32 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from abc import ABC, abstractmethod
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from grammer.utils import state_trace_exact_match, execute_simcode
sf = SmoothingFunction()
from grammer.utils import convert_tokens_to_code
class Metric(ABC):
    """Abstract base class for an evaluation metric reported during training.

    Subclasses accumulate statistics batch-by-batch through ``update`` and
    expose the aggregate through ``eval`` / ``__str__``.
    """

    def __init__(self):
        # Delegate initialization of accumulators to the subclass's reset().
        self.reset()

    @abstractmethod
    def reset(self):
        """Clear all accumulated statistics."""
        raise NotImplementedError()

    @abstractmethod
    def update(self, preds, targets):
        """Fold one batch of predictions and targets into the running totals."""
        raise NotImplementedError()

    @abstractmethod
    def __str__(self):
        """Human-readable summary of the current metric value."""
        raise NotImplementedError()

    @abstractmethod
    def eval(self):
        """Return the aggregated metric value(s)."""
        raise NotImplementedError()
class CorrectAnswersScore(Metric):
    """Fraction of predictions whose executed answer matches the target's,
    together with the average state-trace exact-match score.

    Each ``update`` averages over its batch; ``eval`` averages over all
    batches seen since the last ``reset``.
    """

    def __init__(self, tgt_vocab):
        self._tgt_vocab = tgt_vocab  # kept for interface parity; unused here
        # Metric.__init__ already calls self.reset(); the original called
        # reset() a second time redundantly.
        super(CorrectAnswersScore, self).__init__()

    def reset(self):
        """Clear all accumulated statistics."""
        self._state_scores = 0
        self._correct_answers = 0
        self._number_of_batches = 0

    def update(self, preds, targets):
        """Score one batch of predicted token sequences against their targets.

        preds/targets: iterables of token sequences (batch_size, n_tokens).
        """
        if not preds:
            # Empty batch: nothing to score; dividing by len(preds) below
            # would raise ZeroDivisionError.
            return
        self._number_of_batches += 1
        scores = 0
        correct_answers = 0
        for pred_seq, tgt_seq in zip(preds, targets):
            pred_code = convert_tokens_to_code(pred_seq)
            target_code = convert_tokens_to_code(tgt_seq)
            try:
                answer, state = execute_simcode(target_code, True)
                pred_answer, pred_state = execute_simcode(pred_code, True)
                scores += state_trace_exact_match(state, pred_state)
                # Compare to 6 decimal places to tolerate float noise.
                if round(float(answer), 6) == round(float(pred_answer), 6):
                    correct_answers += 1
            except Exception:
                # Best-effort by design: generated code may fail to parse or
                # execute; such samples simply contribute 0 to both scores.
                pass
        self._state_scores += scores / len(preds)
        self._correct_answers += correct_answers / len(preds)

    def eval(self):
        """Return (mean correct-answer rate, mean state-trace score).

        Returns (0.0, 0.0) before any batch has been scored instead of
        raising ZeroDivisionError.
        """
        if self._number_of_batches == 0:
            return (0.0, 0.0)
        return (self._correct_answers / self._number_of_batches,
                self._state_scores / self._number_of_batches)

    def __str__(self):
        # Plain %-formatting; the original's f-prefix was a no-op.
        return "Correct answers %.4f, State transitions score: %.4f" % self.eval()
class BleuScore(Metric):
    """Corpus-level BLEU (bigram, equal 0.5/0.5 weights) averaged over batches."""

    def reset(self):
        """Clear the batch count and the running BLEU sum."""
        self._n_batches = 0
        self._bleu_scores_sum = 0

    def update(self, preds, targets):
        """Accumulate the BLEU score of one batch.

        preds need to have dimension of batch_size, number_of_tokens
        targets need to have dimension of batch_size, number_of_tokens
        """
        self._n_batches += 1
        # corpus_bleu expects each hypothesis paired with a list of references.
        references = [[tgt] for tgt in targets]
        score = corpus_bleu(
            references,
            preds,
            smoothing_function=sf.method1,
            weights=(0.5, 0.5),
        )
        self._bleu_scores_sum += score

    def eval(self):
        """Return the mean BLEU score over all batches seen so far."""
        return self._bleu_scores_sum / self._n_batches

    def __str__(self):
        return "Bleu: %.4f" % self.eval()
class BleuAndStateScore(Metric):
    """Convex combination of corpus BLEU and state-trace exact match.

    ``eval`` returns ``(1 - alpha) * BLEU + alpha * state_score``, each
    summed per batch and divided by the number of batches seen.
    """

    def __init__(self, tgt_vocab, alpha):
        self._alpha = alpha
        self._tgt_vocab = tgt_vocab  # kept for interface parity; unused here
        # Metric.__init__ already calls self.reset(); the original called
        # reset() a second time redundantly.
        super().__init__()

    def reset(self):
        """Clear all accumulated statistics."""
        self._n_batches = 0
        self._bleu_scores_sum = 0
        self._state_scores = 0
        # Dead accumulator in the original (never incremented or read);
        # retained for backward compatibility with external readers.
        self._number_of_batches = 0

    def update(self, preds, targets):
        """
        preds need to have dimension of batch_size, number_of_tokens
        targets need to have dimension of batch_size, number_of_tokens
        """
        if not preds:
            # Empty batch: nothing to score; dividing by len(preds) below
            # would raise ZeroDivisionError.
            return
        self._n_batches += 1
        refs = [[target] for target in targets]
        bscore = corpus_bleu(refs, preds,
                             smoothing_function=sf.method1,
                             weights=(0.5, 0.5))
        self._bleu_scores_sum += bscore
        state_scores = 0
        for pred_seq, tgt_seq in zip(preds, targets):
            pred_code = convert_tokens_to_code(pred_seq)
            target_code = convert_tokens_to_code(tgt_seq)
            try:
                _, state = execute_simcode(target_code, True)
                _, pred_state = execute_simcode(pred_code, True)
                state_scores += state_trace_exact_match(state, pred_state)
            except Exception:
                # Best-effort by design: predicted code may fail to parse or
                # execute; such samples contribute 0 to the state score.
                pass
        self._state_scores += state_scores / len(preds)

    def eval(self):
        """Return the blended score; 0.0 before any batch has been scored."""
        if self._n_batches == 0:
            return 0.0
        return ((1 - self._alpha) * self._bleu_scores_sum
                + self._alpha * self._state_scores) / self._n_batches

    def __str__(self):
        # Plain %-formatting; the original's f-prefix was a no-op.
        return "Bleu and state score: %.4f" % self.eval()