Skip to content

Commit 04839ad

Browse files
committed
Create script to pull results from experiments
1 parent 5c10527 commit 04839ad

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import argparse
2+
import glob
3+
import os
4+
import re
5+
6+
import pandas as pd
7+
from openpyxl import Workbook
8+
from openpyxl.styles import Alignment, Font, PatternFill
9+
from openpyxl.utils import get_column_letter
10+
11+
from .config import get_mt_exp_dir
12+
13+
# Shared row-count state for the whole script: extract_data() raises it to the
# highest chapter number it sees for the target book and main() raises it to the
# highest "order" number found in the folder names; flatten_dict() emits rows for
# chapters 1..chap_num and create_xlsx() uses it as the block size in the rank
# formulas.
chap_num = 0
14+
15+
16+
def extract_data(filename, metrics, target_book, header_row=5) -> dict:
    """Collect per-chapter metric values for one book from a results spreadsheet.

    The sheet is read with ``pd.read_excel`` using ``header_row`` as the header
    line; column names are stripped and lower-cased before lookup. Each row's
    ``vref`` cell is expected to start with a book code, whitespace and a
    chapter number. Rows for other books are skipped, as are rows whose
    ``vref`` does not have that shape (e.g. blank/NaN cells).

    As a side effect the module-level ``chap_num`` is raised to the largest
    chapter number seen for ``target_book``.

    Args:
        filename: Path of the Excel file to read.
        metrics: Metric column names (matched case-insensitively).
        target_book: Book code to keep, compared exactly with the vref prefix.
        header_row: Zero-based row index passed to pandas as the header row.

    Returns:
        Mapping of chapter number to the list of metric values, in the order
        of ``metrics``. A metric whose column is absent yields ``None``. If a
        chapter appears in several rows, the last row wins.

    Raises:
        KeyError: If the sheet has no ``vref`` column.
    """
    global chap_num

    metrics = [m.lower() for m in metrics]
    df = pd.read_excel(filename, header=header_row)
    # str() guards against non-string headers (numbers, "Unnamed" placeholders).
    df.columns = [str(col).strip().lower() for col in df.columns]

    # Report each missing metric once per file instead of once per row.
    missing = {metric for metric in metrics if metric not in df.columns}
    for metric in metrics:
        if metric in missing:
            print(f"Warning: {metric} is not calculated in {filename}")

    result = {}
    for _, row in df.iterrows():
        # Digits are allowed in the book code so that codes such as "1CO" parse.
        m = re.match(r"([0-9A-Za-z]+)\s+(\d+)", str(row["vref"]))
        if m is None:
            continue

        book_name, chap_text = m.groups()
        if book_name != target_book:
            continue

        chap = int(chap_text)
        chap_num = max(chap_num, chap)
        result[chap] = [None if metric in missing else row[metric] for metric in metrics]
    return result
45+
46+
47+
def flatten_dict(data, metrics, chapters) -> list:
    """Turn the nested results into flat spreadsheet rows.

    One row is produced per language pair and per chapter ``1..chap_num``
    (module-level). Each row is laid out as::

        [lang_pair, chapter,
         <1 + len(metrics) baseline cells>,
         <3 * len(metrics) + 1 cells for every entry of ``chapters``>]

    Inside a group the first cell is left empty for the rank formula and each
    metric value is followed by two empty cells for the diff formulas. Only
    the metric values are filled in here; every other cell stays ``None``.

    Every row is sized from ``chapters`` (the union over all language pairs),
    so a language pair that lacks some groups still gets a full-width row
    instead of an out-of-range index.

    Args:
        data: ``{lang_pair: {group: {chapter: [metric values]}}}``.
        metrics: Metric names; only their count and order matter here.
        chapters: Ordered list of group keys; its order fixes the column order.

    Returns:
        List of rows (lists) ready to be appended to a worksheet.

    Raises:
        ValueError: If a group key in ``data`` is not listed in ``chapters``.
    """
    n_metrics = len(metrics)
    group_width = n_metrics * 3 + 1
    # 0-based index of the first group's rank cell: lang, chapter, baseline.
    first_group = 2 + 1 + n_metrics
    row_width = first_group + group_width * len(chapters)

    res = []
    for lang_pair, per_group in data.items():
        for chap in range(1, chap_num + 1):
            row = [lang_pair, chap] + [None] * (row_width - 2)

            for res_chap, per_chapter in per_group.items():
                if chap not in per_chapter:
                    continue
                # +1 skips the group's rank cell.
                first_value = first_group + chapters.index(res_chap) * group_width + 1
                values = per_chapter[chap]
                for m in range(n_metrics):
                    row[first_value + m * 3] = values[m]
            res.append(row)
    return res
65+
66+
67+
def _write_header_rows(ws, metrics, chapters):
    """Write row 1 (merged group titles) and row 2 (per-column sub-headers)."""
    group_width = len(metrics) * 3 + 1
    groups = [("language pair", 1), ("Chapter", 1), ("Baseline", 1 + len(metrics))]
    groups.extend((chap, group_width) for chap in chapters)

    col = 1
    for header, span in groups:
        # One-column groups are not merged here: a one-cell range would overlap
        # the vertical A1:A2 / B1:B2 merges applied after the data is written.
        if span > 1:
            ws.merge_cells(start_row=1, start_column=col, end_row=1, end_column=col + span - 1)
        ws.cell(row=1, column=col, value=header)
        col += span

    baseline_headers = []
    sub_headers = []
    if metrics:
        baseline_headers.append("rank")
        sub_headers.append("rank")
    for metric in metrics:
        baseline_headers.append(metric)
        sub_headers.extend([metric, "diff (prev)", "diff (start)"])

    for offset, title in enumerate(baseline_headers):
        ws.cell(row=2, column=3 + offset, value=title)

    # Exactly one sub-header group per entry of ``chapters``.
    col = 3 + len(metrics) + 1
    for _ in chapters:
        for offset, title in enumerate(sub_headers):
            ws.cell(row=2, column=col + offset, value=title)
        col += len(sub_headers)


def _rank_formula(col_letter, row_idx, rows_per_block):
    """Build the formula ranking a cell within its block of ``rows_per_block`` rows."""
    column_ref = f"{col_letter}:{col_letter}"
    # 0-based offset of the first row of the block this row belongs to
    # (data starts at sheet row 3).
    block_offset = f"INT((ROW({col_letter}{row_idx})-3)/{rows_per_block})*{rows_per_block}"
    # The "_xlfn." prefix is what openpyxl requires for functions that were not
    # part of the original OOXML specification, such as RANK.EQ.
    return (
        f"=_xlfn.RANK.EQ({col_letter}{row_idx}, "
        f"INDEX({column_ref}, {block_offset}+3):"
        f"INDEX({column_ref}, {block_offset}+{rows_per_block}+2), 0)"
    )


def _fill_result_row(ws, row_idx, n_metrics, rows_per_block, grey_fill):
    """Add rank/diff formulas to one data row; grey out from the first empty group on."""
    first_group_col = 3 + n_metrics + 1
    group_width = 3 * n_metrics + 1
    last_col = ws.max_column

    rank_col = first_group_col
    while rank_col < last_col:
        value_col = rank_col + 1  # first metric of the group
        if ws.cell(row=row_idx, column=value_col).value is None:
            for col in range(rank_col, last_col + 1):
                ws.cell(row=row_idx, column=col).fill = grey_fill
            break

        # The rank is computed on the group's first metric only.
        ws.cell(row=row_idx, column=rank_col).value = _rank_formula(
            get_column_letter(value_col), row_idx, rows_per_block
        )

        is_first_group = rank_col == first_group_col
        for i in range(1, n_metrics + 1):
            metric_col = value_col + 3 * (i - 1)
            baseline_letter = get_column_letter(3 + i)
            cur_letter = get_column_letter(metric_col)
            # "prev" is the baseline for the first group, otherwise the same
            # metric one group to the left.
            prev_letter = baseline_letter if is_first_group else get_column_letter(metric_col - group_width)

            ws.cell(row=row_idx, column=metric_col + 1).value = f"={cur_letter}{row_idx}-{prev_letter}{row_idx}"
            ws.cell(row=row_idx, column=metric_col + 2).value = f"={cur_letter}{row_idx}-{baseline_letter}{row_idx}"

        rank_col += group_width


def create_xlsx(rows, metrics, chapters, output_path):
    """Write the summary workbook: headers, data rows, formulas and formatting.

    Layout: column 1 is the language pair, column 2 the chapter, followed by a
    "Baseline" group (rank + one column per metric) and one group per entry of
    ``chapters`` (rank, then value / diff (prev) / diff (start) per metric).
    Data rows start at sheet row 3. For every data row, rank and diff formulas
    are added to each group whose first metric cell has a value; from the
    first group without a value onwards the rest of the row is filled grey.
    Consecutive rows with the same language pair get their first column
    merged. The module-level ``chap_num`` is used as the number of rows per
    language pair in the rank formulas.

    Args:
        rows: Row lists laid out as described above (see ``flatten_dict``).
        metrics: Metric names, used as sub-header titles.
        chapters: Group keys, used as the merged group titles in row 1.
        output_path: Destination path of the ``.xlsx`` file.
    """
    wb = Workbook()
    ws = wb.active

    _write_header_rows(ws, metrics, chapters)

    for row in rows:
        ws.append(row)

    for row_idx in (1, 2):
        for col in range(1, ws.max_column + 1):
            cell = ws.cell(row=row_idx, column=col)
            cell.font = Font(bold=True)
            cell.alignment = Alignment(horizontal="center", vertical="center")

    # The two single-column titles span both header rows.
    ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
    ws.merge_cells(start_row=1, start_column=2, end_row=2, end_column=2)
    ws.cell(row=1, column=1).alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")

    grey_fill = PatternFill(fill_type="solid", start_color="CCCCCC", end_color="CCCCCC")
    last_row = ws.max_row
    block_start = 3  # first row of the language-pair block currently being scanned
    for row_idx in range(3, last_row + 1):
        _fill_result_row(ws, row_idx, len(metrics), chap_num, grey_fill)

        # Merging happens only once a block is complete, so the comparison
        # below always reads cells that still hold their value.
        block_end = None
        if ws.cell(row=row_idx, column=1).value != ws.cell(row=block_start, column=1).value:
            block_end = row_idx - 1
        elif row_idx == last_row:
            block_end = row_idx

        if block_end is not None:
            if block_end > block_start:  # nothing to merge for a one-row block
                ws.merge_cells(start_row=block_start, start_column=1, end_row=block_end, end_column=1)
            if block_end == row_idx - 1:
                block_start = row_idx

    wb.save(output_path)
165+
166+
167+
def main() -> None:
    """Gather diff_predictions results of one experiment into a summary workbook.

    Every sub-folder of the experiment directory whose name looks like a
    language pair (``<src>-<trg>``) is scanned for result folders named
    ``<books>_<key_word>_order_<N>_ch``. The first ``diff_predictions*`` file
    in each such folder is parsed with ``extract_data`` and stored under its
    order number ``N``. The collected data is flattened and written to
    ``<experiment>/a_result_folder/<books>.xlsx``.
    """
    global chap_num

    # TODO: Add args for books, metrics, key word, baseline
    parser = argparse.ArgumentParser(description="Pull results")
    parser.add_argument("exp1", type=str, help="Experiment folder")
    args = parser.parse_args()

    trained_books = ["MRK"]
    target_book = ["MAT"]
    all_books = trained_books + target_book

    metrics = ["chrf3", "confidence"]

    key_word = "conf"

    exp1_dir = get_mt_exp_dir(args.exp1)

    folder_name = "+".join(all_books)
    result_dir = os.path.join(exp1_dir, "a_result_folder")
    os.makedirs(result_dir, exist_ok=True)
    output_path = os.path.join(result_dir, f"{folder_name}.xlsx")

    # Compiled once instead of once per directory entry.
    lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
    group_pattern = re.compile(rf"^{re.escape(folder_name)}_{re.escape(key_word)}_order_(\d+)_ch$")

    data = {}
    orders = set()

    # Sorted so the row order of the workbook does not depend on the file system.
    for lang_pair in sorted(os.listdir(exp1_dir)):
        lang_dir = os.path.join(exp1_dir, lang_pair)
        # A hyphenated regular file also matches the pattern; only folders can be listed.
        if not lang_pattern.match(lang_pair) or not os.path.isdir(lang_dir):
            continue

        data[lang_pair] = {}

        for group_name in sorted(os.listdir(lang_dir)):
            m = group_pattern.match(group_name)
            if m is None:
                continue

            order = int(m.group(1))
            folder_path = os.path.join(lang_dir, group_name)
            candidates = sorted(glob.glob(os.path.join(glob.escape(folder_path), "diff_predictions*")))
            if not candidates:
                print(group_name + " has no diff_predictions file.")
                continue

            data[lang_pair][order] = extract_data(candidates[0], metrics, target_book[0])
            orders.add(order)
            # NOTE(review): chap_num is also raised to the largest order number,
            # as in the original; confirm this is intended, since chap_num
            # otherwise counts chapters of the target book.
            chap_num = max(chap_num, order)

    chapters = sorted(orders)
    print("Writing data...")
    rows = flatten_dict(data, metrics, chapters)
    create_xlsx(rows, metrics, chapters, output_path)
    print(f"Result is in {output_path}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)