|
| 1 | +import argparse |
| 2 | +import glob |
| 3 | +import os |
| 4 | +import re |
| 5 | + |
| 6 | +import pandas as pd |
| 7 | +from openpyxl import Workbook |
| 8 | +from openpyxl.styles import Alignment, Font, PatternFill |
| 9 | +from openpyxl.utils import get_column_letter |
| 10 | + |
| 11 | +from .config import get_mt_exp_dir |
| 12 | + |
# Shared module-level counter. extract_data() raises it to the highest chapter
# number seen for the target book and main() raises it to the highest group
# order number; flatten_dict() and create_xlsx() read it as the number of
# chapter rows emitted per language pair.
chap_num = 0
| 14 | + |
| 15 | + |
def extract_data(filename, metrics, target_book, header_row=5) -> dict:
    """Collect per-chapter metric values for one book from a spreadsheet.

    The sheet is read with ``pd.read_excel`` using ``header_row`` as the header
    line; column names are stripped and lower-cased before lookup. Each row's
    ``vref`` cell is expected to start with "<book> <chapter>". Rows whose
    book equals ``target_book`` contribute one list of metric values keyed by
    chapter number.

    Side effect: the module-level ``chap_num`` is raised to the highest chapter
    number seen for ``target_book``.

    Args:
        filename: Path of the spreadsheet to read.
        metrics: Metric column names (matched case-insensitively).
        target_book: Book identifier to keep; compared exactly.
        header_row: Header row index passed to ``pd.read_excel``.

    Returns:
        Mapping of chapter number to a list of values ordered like ``metrics``.
        A metric whose column is missing yields ``None``. When several rows
        share a chapter, the last one wins.

    Raises:
        KeyError: If the sheet has data rows but no ``vref`` column.
    """
    global chap_num

    metrics = [m.lower() for m in metrics]
    df = pd.read_excel(filename, header=header_row)
    # str() keeps non-string column labels (e.g. numbers) from raising.
    df.columns = [str(col).strip().lower() for col in df.columns]

    # Resolve metric availability once and warn once per file, not once per row.
    available = [metric in df.columns for metric in metrics]
    for metric, found in zip(metrics, available):
        if not found:
            print(f"Warning: {metric} is not calculated in {filename}")

    # NOTE(review): book codes that start with a digit (e.g. "1JN") do not
    # match this pattern - confirm whether they need to be supported.
    vref_pattern = re.compile(r"([A-Za-z]+)\s+(\d+)")

    result = {}
    for _, row in df.iterrows():
        m = vref_pattern.match(str(row["vref"]))
        if m is None:
            # Blank or non-reference cells carry no chapter information.
            continue

        book_name, chap_text = m.groups()
        if book_name != target_book:
            continue

        chap = int(chap_text)
        chap_num = max(chap_num, chap)
        result[chap] = [row[metric] if found else None for metric, found in zip(metrics, available)]
    return result
| 45 | + |
| 46 | + |
def flatten_dict(data, metrics, chapters) -> list:
    """Turn the nested results into one spreadsheet row per language pair and chapter.

    Row layout (0-based): language pair, chapter, ``1 + len(metrics)`` baseline
    slots, then for every entry of ``chapters`` a group of
    ``3 * len(metrics) + 1`` slots (rank, then value / diff (prev) /
    diff (start) for each metric). Only the metric value slots are filled
    here; every other slot is left as ``None``.

    Every row has the full width implied by ``chapters``. The width used to be
    derived from the number of groups present for the language pair, which
    produced short rows (and an ``IndexError``) when a pair lacked some groups.

    Args:
        data: ``{lang_pair: {group: {chapter: [value per metric]}}}``.
        metrics: Metric names; only their count and order are used.
        chapters: Ordered group identifiers that define the column groups.

    Returns:
        A list of rows, ``chap_num`` rows per language pair (chapters
        ``1..chap_num`` of the module-level counter).

    Raises:
        ValueError: If a group that has data for a chapter is not listed in
            ``chapters``.
    """
    global chap_num

    n_metrics = len(metrics)
    group_width = 3 * n_metrics + 1
    # Index of the first chapter group: 2 leading slots + baseline (rank + metrics).
    first_group = 2 + 1 + n_metrics
    row_width = first_group + group_width * len(chapters)

    res = []
    for lang_pair, groups in data.items():
        for chap in range(1, chap_num + 1):
            row = [lang_pair, chap] + [None] * (row_width - 2)

            for group, per_chapter in groups.items():
                if chap not in per_chapter:
                    continue
                # +1 skips the group's rank slot; metric values sit every 3 slots.
                first_value = first_group + chapters.index(group) * group_width + 1
                values = per_chapter[chap]
                for m in range(n_metrics):
                    row[first_value + 3 * m] = values[m]
            res.append(row)
    return res
| 65 | + |
| 66 | + |
def create_xlsx(rows, metrics, chapters, output_path):
    """Write the flattened rows to a formatted workbook at ``output_path``.

    Sheet layout (1-based columns): language pair, chapter, a "Baseline" group
    (rank + one column per metric), then one group per entry of ``chapters``
    holding rank followed by value / diff (prev) / diff (start) per metric.
    Row 1 carries the group titles, row 2 the sub-headers, data starts at row 3.

    For every data row the chapter groups are walked left to right. A group
    whose first metric value is empty greys out the rest of the row. Otherwise
    a RANK.EQ formula (ranking within the row's block of ``chap_num`` rows) and
    the two difference formulas per metric are written. Consecutive rows with
    the same language pair get their first-column cells merged.

    Changes from the previous behaviour:
      * sub-headers are written once per chapter group; the old loop ran one
        extra time and created an untitled, always-grey group at the right;
      * single-cell ranges are no longer "merged", so the A1:A2 / B1:B2 merges
        do not overlap an earlier A1:A1 / B1:B1 merge.

    NOTE(review): openpyxl's documentation says functions that are not part of
    the original specification need an ``_xlfn.`` prefix - confirm RANK.EQ
    evaluates correctly in the spreadsheet application you target.

    Args:
        rows: Row lists to append below the two header rows.
        metrics: Metric names used for sub-headers and column arithmetic.
        chapters: Group identifiers, one column group each.
        output_path: Destination passed to ``Workbook.save``.
    """
    global chap_num

    wb = Workbook()
    ws = wb.active

    n_metrics = len(metrics)
    group_width = n_metrics * 3 + 1
    # Column of the first chapter group's rank cell (after pair, chapter, baseline).
    first_group_col = 3 + n_metrics + 1

    # --- Row 1: group titles -------------------------------------------------
    groups = [("language pair", 1), ("Chapter", 1), ("Baseline", 1 + n_metrics)]
    groups.extend((chap, group_width) for chap in chapters)

    col = 1
    for header, span in groups:
        if span > 1:
            ws.merge_cells(start_row=1, start_column=col, end_row=1, end_column=col + span - 1)
        ws.cell(row=1, column=col, value=header)
        col += span

    # --- Row 2: sub-headers --------------------------------------------------
    baseline_headers = []
    sub_headers = []
    if metrics:
        baseline_headers.append("rank")
        sub_headers.append("rank")
    for metric in metrics:
        baseline_headers.append(metric)
        sub_headers.extend([metric, "diff (prev)", "diff (start)"])

    for offset, title in enumerate(baseline_headers):
        ws.cell(row=2, column=3 + offset, value=title)

    col = first_group_col
    for _ in chapters:
        for offset, title in enumerate(sub_headers):
            ws.cell(row=2, column=col + offset, value=title)
        col += len(sub_headers)

    # --- Data rows (appended below the header rows) ---------------------------
    for row in rows:
        ws.append(row)

    # The sheet dimensions are fixed from here on; openpyxl recomputes
    # max_row / max_column on every access, so read them once.
    last_col = ws.max_column
    last_row = ws.max_row

    bold = Font(bold=True)
    centered = Alignment(horizontal="center", vertical="center")
    for row_idx in (1, 2):
        for col in range(1, last_col + 1):
            header_cell = ws.cell(row=row_idx, column=col)
            header_cell.font = bold
            header_cell.alignment = centered

    ws.merge_cells(start_row=1, start_column=1, end_row=2, end_column=1)
    ws.merge_cells(start_row=1, start_column=2, end_row=2, end_column=2)
    ws.cell(row=1, column=1).alignment = Alignment(wrap_text=True, horizontal="center", vertical="center")

    grey = PatternFill(fill_type="solid", start_color="CCCCCC", end_color="CCCCCC")

    pair_start = 3  # first row of the language pair currently being scanned
    for row_idx in range(3, last_row + 1):
        group_col = first_group_col  # rank column of the group being processed
        while group_col < last_col:
            first_value_col = group_col + 1
            if ws.cell(row=row_idx, column=first_value_col).value is None:
                # No data from this group onwards: grey out the rest of the row.
                for col in range(group_col, last_col + 1):
                    ws.cell(row=row_idx, column=col).fill = grey
                break

            # Rank of this row's first metric within its block of chap_num rows.
            letter = get_column_letter(first_value_col)
            cell_ref = f"{letter}{row_idx}"
            column_ref = f"{letter}:{letter}"
            block_start = f"INT((ROW({cell_ref})-3)/{chap_num})*{chap_num}"
            ws.cell(row=row_idx, column=group_col).value = (
                f"=RANK.EQ({cell_ref}, "
                f"INDEX({column_ref}, {block_start}+3):"
                f"INDEX({column_ref}, {block_start}+{chap_num}+2), 0)"
            )

            in_first_group = group_col == first_group_col
            for i in range(1, n_metrics + 1):
                value_col = group_col + 1 + 3 * (i - 1)
                baseline_letter = get_column_letter(3 + i)
                cur_letter = get_column_letter(value_col)
                # "prev" is the baseline for the first group, otherwise the
                # same metric one group to the left.
                prev_letter = baseline_letter if in_first_group else get_column_letter(value_col - group_width)

                ws.cell(row=row_idx, column=value_col + 1).value = f"={cur_letter}{row_idx}-{prev_letter}{row_idx}"
                ws.cell(row=row_idx, column=value_col + 2).value = f"={cur_letter}{row_idx}-{baseline_letter}{row_idx}"

            group_col += group_width

        # Merge the language-pair cells of each run of identical names.
        if ws.cell(row=row_idx, column=1).value != ws.cell(row=pair_start, column=1).value:
            if row_idx - 1 > pair_start:
                ws.merge_cells(start_row=pair_start, start_column=1, end_row=row_idx - 1, end_column=1)
            pair_start = row_idx
        elif row_idx == last_row and row_idx > pair_start:
            ws.merge_cells(start_row=pair_start, start_column=1, end_row=row_idx, end_column=1)

    wb.save(output_path)
| 165 | + |
| 166 | + |
def main() -> None:
    """Command-line entry point: collect results for one experiment and write the summary workbook.

    Takes one positional argument, the experiment name, which is resolved with
    ``get_mt_exp_dir``. Every sub-directory whose name matches the
    language-pair pattern is scanned for folders named
    ``<books>_<key_word>_order_<N>_ch``. The first ``diff_predictions*`` file
    in each such folder is parsed with ``extract_data`` and stored under
    ``N``. The combined data is flattened and written to
    ``<experiment dir>/a_result_folder/<books>.xlsx``.

    Side effect: the module-level ``chap_num`` is raised to the largest ``N``
    encountered (in addition to the updates made by ``extract_data``).
    """
    global chap_num

    # TODO: Add args for books, metrics, key word, baseline
    parser = argparse.ArgumentParser(description="Pull results")
    parser.add_argument("exp1", type=str, help="Experiment folder")
    args = parser.parse_args()

    trained_books = ["MRK"]
    target_book = ["MAT"]
    all_books = trained_books + target_book

    metrics = ["chrf3", "confidence"]

    key_word = "conf"

    exp1_dir = get_mt_exp_dir(args.exp1)

    folder_name = "+".join(all_books)
    result_dir = os.path.join(exp1_dir, "a_result_folder")
    os.makedirs(result_dir, exist_ok=True)
    output_path = os.path.join(result_dir, f"{folder_name}.xlsx")

    # Both patterns are loop-invariant, so compile them once up front.
    lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
    group_pattern = re.compile(rf"^{re.escape(folder_name)}_{re.escape(key_word)}_order_(\d+)_ch$")

    data = {}
    chapters = set()

    # Sorted so the row order of the workbook does not depend on directory order.
    for lang_pair in sorted(os.listdir(exp1_dir)):
        lang_dir = os.path.join(exp1_dir, lang_pair)
        # Plain files can match the name pattern too; listing them would raise.
        if not lang_pattern.match(lang_pair) or not os.path.isdir(lang_dir):
            continue

        data[lang_pair] = {}

        for group_name in os.listdir(lang_dir):
            m = group_pattern.match(group_name)
            if m is None:
                continue

            order = int(m.group(1))
            folder_path = os.path.join(lang_dir, group_name)
            diff_pred_files = glob.glob(os.path.join(folder_path, "diff_predictions*"))
            if not diff_pred_files:
                print(group_name + " has no diff_predictions file.")
                continue

            data[lang_pair][order] = extract_data(diff_pred_files[0], metrics, target_book[0])
            chapters.add(order)
            chap_num = max(chap_num, order)

    chapters = sorted(chapters)
    print("Writing data...")
    rows = flatten_dict(data, metrics, chapters)
    create_xlsx(rows, metrics, chapters, output_path)
    print(f"Result is in {output_path}")
| 222 | + |
| 223 | + |
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments