-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathget_math_results.py
More file actions
executable file
·100 lines (78 loc) · 3.19 KB
/
Copy pathget_math_results.py
File metadata and controls
executable file
·100 lines (78 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import json
from tqdm import tqdm, trange
from eval_math_rule.evaluation.grader import math_equal
from eval_math_rule.evaluation.parser import extract_answer, parse_ground_truth, strip_string
from collections import Counter
import multiprocessing
import queue
def _math_equal_worker(pred, gt_ans, result_queue):
    """Child-process entry point: put ``math_equal(pred, gt_ans)`` (or the raised exception) on ``result_queue``.

    Defined at module level (not as a closure) so it can be pickled and sent
    to the child under the ``spawn``/``forkserver`` start methods.
    """
    try:
        result_queue.put(math_equal(pred, gt_ans))
    except Exception as e:
        # Ship the exception back so the parent can report it instead of
        # seeing an empty queue.
        result_queue.put(e)


def math_equal_with_timeout(pred, gt_ans, timeout):
    """Run ``math_equal(pred, gt_ans)`` in a separate process with a time limit.

    Args:
        pred: Predicted answer to compare.
        gt_ans: Ground-truth answer to compare against.
        timeout: Maximum number of seconds to wait for the comparison.

    Returns:
        The value returned by ``math_equal``, or ``False`` if the comparison
        timed out, raised an exception, or the worker exited without
        producing a result.
    """
    result_queue = multiprocessing.Queue()
    # Pass pred/gt_ans as explicit args: a nested-function target cannot be
    # pickled, which breaks under the spawn/forkserver start methods.
    process = multiprocessing.Process(
        target=_math_equal_worker, args=(pred, gt_ans, result_queue)
    )
    process.start()
    process.join(timeout)
    if process.is_alive():
        print(f"Timeout occurred for prediction: {pred}")
        process.terminate()
        process.join()
        return False
    try:
        result = result_queue.get_nowait()
    except queue.Empty:
        # The worker finished but left nothing on the queue (e.g. it crashed
        # or its result could not be pickled).
        print("Result queue timed out")
        return False
    if isinstance(result, Exception):
        print(f"Error occurred: {result}")
        return False
    return result
def parallel_math_equal(all_pred, gt_ans, timeout=20):
    """Compare every prediction in ``all_pred`` against ``gt_ans``.

    Despite the name, the predictions are checked one after another: each
    ``math_equal_with_timeout`` call blocks until its worker process finishes
    or ``timeout`` seconds elapse.

    Returns:
        A list of per-prediction results, in the same order as ``all_pred``.
    """
    return [
        math_equal_with_timeout(candidate, gt_ans, timeout)
        for candidate in all_pred
    ]
def _load_jsonl(path):
    """Read a JSON-lines file into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _score_example(example, k=None):
    """Grade every generation of one example and record the majority-vote result in place."""
    # Some result files store generations under "model_output" instead.
    if "model_generation" not in example:
        example["model_generation"] = example["model_output"]
    if k is not None:
        example["model_generation"] = example["model_generation"][:k]
    generations = example["model_generation"]

    gt_ans = extract_answer(example["answer"], data_name="omni-math")
    gt_ans = strip_string(gt_ans, skip_unit=False)

    all_pred = [extract_answer(g, data_name="omni-math") for g in generations]
    all_pred = [strip_string(p, skip_unit=False) for p in all_pred]
    all_eval = parallel_math_equal(all_pred, gt_ans, timeout=5)

    # Vote only among generations whose text contains "boxed"; if none do,
    # fall back to voting over all predictions.
    effective_pred = [p for p, g in zip(all_pred, generations) if "boxed" in g]
    if not effective_pred:
        effective_pred = all_pred

    if effective_pred:
        mv_pred = Counter(effective_pred).most_common(1)[0][0]
        mv_index = all_pred.index(mv_pred)
        mv_eval = all_eval[mv_index]
    else:
        # No generations at all: nothing to vote on, so count it as incorrect
        # instead of crashing on an empty Counter.
        mv_pred, mv_index, mv_eval = None, None, False

    example["all_pred"] = all_pred
    example["all_eval"] = all_eval
    example["mv_pred"] = mv_pred
    example["mv_eval"] = mv_eval
    example["mv_index"] = mv_index


def _save_results(data, acc, output_dir):
    """Write graded examples to math_eval.jsonl and the accuracy to metrics.json in output_dir."""
    os.makedirs(output_dir, exist_ok=True)
    out_file = os.path.join(output_dir, "math_eval.jsonl")
    with open(out_file, "w", encoding="utf-8") as f:
        for example in data:
            f.write(json.dumps(example) + "\n")
    metric_file = os.path.join(output_dir, "metrics.json")
    with open(metric_file, "w", encoding="utf-8") as f:
        json.dump({"acc": acc}, f)


def main(res_path, save=False, k=None, output_dir=None):
    """Grade a JSON-lines file of model generations by majority vote and report accuracy.

    Each line must be a JSON object with an ``"answer"`` field and either
    ``"model_generation"`` or ``"model_output"`` (a list of generations).
    Each example is annotated in place with ``all_pred``, ``all_eval``,
    ``mv_pred``, ``mv_eval`` and ``mv_index``.

    Args:
        res_path: Path to the input JSON-lines file.
        save: If true, write ``math_eval.jsonl`` and ``metrics.json``.
        k: If given, only the first ``k`` generations per example are graded.
        output_dir: Directory for the saved files; defaults to the directory
            containing ``res_path``. It is created if missing.

    Returns:
        The majority-vote accuracy as a float.

    Raises:
        ValueError: If ``res_path`` contains no examples.
    """
    data = _load_jsonl(res_path)
    if not data:
        raise ValueError(f"No examples found in {res_path}")

    for example in tqdm(data):
        _score_example(example, k)

    acc = sum(example["mv_eval"] for example in data) / len(data)
    print(f"Accuracy: {acc:.3f}")

    if save:
        if output_dir is None:
            output_dir = os.path.dirname(os.path.abspath(res_path))
        _save_results(data, acc, output_dir)
    return acc