Skip to content

Commit d024a19

Browse files
committed
Fix bugs, improve arg parse, update baseline input
1 parent 04839ad commit d024a19

File tree

1 file changed

+110
-44
lines changed

1 file changed

+110
-44
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 110 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,56 @@
1111
from .config import get_mt_exp_dir
1212

1313
chap_num = 0
14+
trained_books = []
15+
target_book = ""
16+
all_books = []
17+
metrics = []
18+
key_word = ""
1419

1520

16-
def extract_data(filename, metrics, target_book, header_row=5) -> dict:
21+
def read_data(file_path, data, chapters):
    """Collect extracted results for every language-pair folder in *file_path*.

    An entry of *file_path* is treated as a language pair when it is a
    directory whose name starts with ``<something>-<something>``. Inside it,
    sub-folders named ``<books>_<key_word>_order_<N>_ch`` are looked up, where
    ``<books>`` is the module-level ``all_books`` joined with ``+``. For each
    such folder holding a ``diff_predictions*`` file, the file is passed to
    ``extract_data`` and the result is stored under the integer ``N``.

    Args:
        file_path: Directory that contains the language-pair folders.
        data: Dict filled in place as ``data[lang_pair][N] = extract_data(...)``.
            Every accepted language pair gets an entry, even if it stays empty.
        chapters: List extended in place with each ``N`` that produced a result
            (duplicates are not removed here).

    Side effects:
        Raises the module-level ``chap_num`` to the largest ``N`` seen, and
        prints a message for matching folders without a diff_predictions file.
    """
    global chap_num

    # Both patterns are loop-invariant, so they are compiled once up front.
    lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
    prefix = "+".join(all_books)
    # key_word is user-supplied text, so it is escaped like the book prefix.
    group_pattern = re.compile(rf"^{re.escape(prefix)}_{re.escape(key_word)}_order_(\d+)_ch$")

    for lang_pair in os.listdir(file_path):
        lang_dir = os.path.join(file_path, lang_pair)
        # match() only anchors the start, so a plain file such as "en-fr.xlsx"
        # would pass the name check and then break os.listdir() below.
        if not lang_pattern.match(lang_pair) or not os.path.isdir(lang_dir):
            continue

        data[lang_pair] = {}
        for group in os.listdir(lang_dir):
            m = group_pattern.match(group)
            if not m:
                continue

            folder_path = os.path.join(lang_dir, group)
            # glob order is filesystem-dependent; sorting makes the pick stable.
            diff_pred_files = sorted(glob.glob(os.path.join(folder_path, "diff_predictions*")))
            if not diff_pred_files:
                print(folder_path + " has no diff_predictions file.")
                continue

            order = int(m.group(1))
            data[lang_pair][order] = extract_data(diff_pred_files[0])
            chapters.append(order)
            chap_num = max(chap_num, order)
48+
49+
50+
def extract_data(filename, header_row=5) -> dict:
51+
global chap_num
52+
global metrics
53+
global target_book
1854

1955
metrics = [m.lower() for m in metrics]
2056
df = pd.read_excel(filename, header=header_row)
2157
df.columns = [col.strip().lower() for col in df.columns]
2258

2359
result = {}
60+
metric_warning = False
2461
for _, row in df.iterrows():
2562
vref = row["vref"]
26-
m = re.match(r"([A-Za-z]+)\s+(\d+)", str(vref))
63+
m = re.match(r"(\d?[A-Z]{2,3}) (\d+)", str(vref))
2764

2865
book_name, chap = m.groups()
2966
if book_name != target_book:
@@ -37,17 +74,22 @@ def extract_data(filename, metrics, target_book, header_row=5) -> dict:
3774
if metric in row:
3875
values.append(row[metric])
3976
else:
40-
print("Warning: {metric} is not calculated in {filename}")
77+
metric = True
4178
values.append(None)
4279

4380
result[int(chap)] = values
81+
82+
if metric_warning:
83+
print("Warning: {metric} is not calculated in {filename}")
84+
4485
return result
4586

4687

47-
def flatten_dict(data, metrics, chapters) -> list:
88+
def flatten_dict(data, chapters, baseline={}) -> list:
4889
global chap_num
90+
global metrics
4991

50-
res = []
92+
rows = []
5193
for lang_pair in data:
5294
for chap in range(1, chap_num + 1):
5395
row = [lang_pair, chap]
@@ -60,12 +102,16 @@ def flatten_dict(data, metrics, chapters) -> list:
60102
for m in range(len(metrics)):
61103
index_m = 3 + 1 + len(metrics) + chapters.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
62104
row[index_m] = data[lang_pair][res_chap][chap][m]
63-
res.append(row)
64-
return res
105+
if len(baseline) > 0:
106+
for m in range(len(metrics)):
107+
row[3 + m] = baseline[lang_pair][chap][m]
108+
rows.append(row)
109+
return rows
65110

66111

67-
def create_xlsx(rows, metrics, chapters, output_path):
112+
def create_xlsx(rows, chapters, output_path):
68113
global chap_num
114+
global metrics
69115

70116
wb = Workbook()
71117
ws = wb.active
@@ -104,8 +150,9 @@ def create_xlsx(rows, metrics, chapters, output_path):
104150
ws.cell(row=2, column=col + i, value=sub_header)
105151

106152
col += len(sub_headers)
107-
for row in rows:
108-
ws.append(row)
153+
154+
for row in rows:
155+
ws.append(row)
109156

110157
for row_idx in [1, 2]:
111158
for col in range(1, ws.max_column + 1):
@@ -118,6 +165,12 @@ def create_xlsx(rows, metrics, chapters, output_path):
118165

119166
cur_lang_pair = 3
120167
for row_idx in range(3, ws.max_row + 1):
168+
if ws.cell(row=row_idx, column=4).value is not None:
169+
ws.cell(row=row_idx, column=3).value = (
170+
f"=RANK.EQ(D{row_idx}, INDEX(D:D, INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+3):INDEX(D:D, \
171+
INT((ROW(D{row_idx})-3)/{chap_num})*{chap_num}+{chap_num}+2), 0)"
172+
)
173+
121174
start_col = 3 + len(metrics) + 1
122175
end_col = ws.max_column
123176

@@ -164,60 +217,73 @@ def create_xlsx(rows, metrics, chapters, output_path):
164217
wb.save(output_path)
165218

166219

220+
# Sample command:
221+
# python -m silnlp.nmt.exp_summary Catapult_Reloaded_Confidences
222+
# --trained-books MRK --target-book MAT --metrics chrf3 confidence --key-word conf --baseline Catapult_Reloaded/2nd_book/MRK
167223
def main() -> None:
    """Parse the command line, gather experiment results and write the summary workbook.

    The parsed options are stored in the module-level ``trained_books``,
    ``target_book``, ``all_books``, ``metrics`` and ``key_word`` because the
    other functions in this module read them as globals. Results for the main
    experiment are collected with ``read_data``; when ``--baseline`` is given,
    one ``diff_predictions*`` file per language-pair folder of the baseline
    experiment is passed to ``extract_data``. The rows are written to
    ``<exp dir>/a_result_folder/<books joined by '+'>.xlsx``.
    """
    global trained_books
    global target_book
    global all_books
    global metrics
    global key_word

    parser = argparse.ArgumentParser(description="Pull results")
    parser.add_argument("exp1", type=str, help="Experiment folder")
    parser.add_argument(
        "--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp"
    )
    parser.add_argument("--target-book", required=True, type=str.upper, help="Book that is going to be analyzed")
    parser.add_argument(
        "--metrics",
        nargs="*",
        metavar="metrics",
        default=["chrf3", "confidence"],
        type=str.lower,
        help="Metrics that will be analyzed with",
    )
    parser.add_argument("--key-word", type=str, default="conf", help="Key word in the filename for the exp group")
    parser.add_argument("--baseline", type=str, help="Baseline for the exp group")
    args = parser.parse_args()

    trained_books = args.trained_books
    target_book = args.target_book
    all_books = trained_books + [target_book]
    metrics = args.metrics
    key_word = args.key_word

    exp1_dir = get_mt_exp_dir(args.exp1)
    # The baseline is optional; without it the baseline columns are left unfilled.
    exp2_dir = get_mt_exp_dir(args.baseline) if args.baseline else None

    folder_name = "+".join(all_books)
    result_dir = os.path.join(exp1_dir, "a_result_folder")
    os.makedirs(result_dir, exist_ok=True)
    output_path = os.path.join(result_dir, f"{folder_name}.xlsx")

    data = {}
    chapters = []
    read_data(exp1_dir, data, chapters)
    chapters = sorted(set(chapters))

    baseline_data = {}
    if exp2_dir:
        # Loop-invariant, so compiled once rather than per directory entry.
        lang_pattern = re.compile(r"([\w-]+)\-([\w-]+)")
        for lang_pair in os.listdir(exp2_dir):
            baseline_path = os.path.join(exp2_dir, lang_pair)
            # Plain files whose names look like a language pair are not experiment folders.
            if not lang_pattern.match(lang_pair) or not os.path.isdir(baseline_path):
                continue

            # glob order is filesystem-dependent; sorting makes the pick stable.
            baseline_diff_pred = sorted(glob.glob(os.path.join(baseline_path, "diff_predictions*")))
            if baseline_diff_pred:
                baseline_data[lang_pair] = extract_data(baseline_diff_pred[0])
            else:
                print(f"Baseline experiment has no diff_predictions file in {baseline_path}")

    print("Writing data...")
    rows = flatten_dict(data, chapters, baseline=baseline_data)
    create_xlsx(rows, chapters, output_path)
    print(f"Result is in {output_path}")
222288

223289

0 commit comments

Comments
 (0)