Skip to content

Commit 505c80e

Browse files
committed
Update utility for baseline (single exp) data extraction only
1 parent b37373f commit 505c80e

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

silnlp/nmt/exp_summary.py

Lines changed: 40 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -90,22 +90,33 @@ def flatten_dict(data: dict, chapters: list, baseline={}) -> list:
9090
global metrics
9191

9292
rows = []
93-
for lang_pair in data:
94-
for chap in range(1, chap_num + 1):
95-
row = [lang_pair, chap]
96-
row.extend([None, None, None] * len(metrics) * len(data[lang_pair]))
97-
row.extend([None] * len(chapters))
98-
row.extend([None] * (1 + len(metrics)))
99-
100-
for res_chap in data[lang_pair]:
101-
if chap in data[lang_pair][res_chap]:
93+
if len(data) > 0:
94+
for lang_pair in data:
95+
for chap in range(1, chap_num + 1):
96+
row = [lang_pair, chap]
97+
row.extend([None, None, None] * len(metrics) * len(data[lang_pair]))
98+
row.extend([None] * len(chapters))
99+
row.extend([None] * (1 + len(metrics)))
100+
101+
for res_chap in data[lang_pair]:
102+
if chap in data[lang_pair][res_chap]:
103+
for m in range(len(metrics)):
104+
index_m = 3 + 1 + len(metrics) + chapters.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
105+
row[index_m] = data[lang_pair][res_chap][chap][m]
106+
if len(baseline) > 0:
102107
for m in range(len(metrics)):
103-
index_m = 3 + 1 + len(metrics) + chapters.index(res_chap) * (len(metrics) * 3 + 1) + m * 3
104-
row[index_m] = data[lang_pair][res_chap][chap][m]
105-
if len(baseline) > 0:
108+
row[3 + m] = baseline[lang_pair][chap][m]
109+
rows.append(row)
110+
else:
111+
for lang_pair in baseline:
112+
for chap in range(1, chap_num + 1):
113+
row = [lang_pair, chap]
114+
row.extend([None] * (1 + len(metrics)))
115+
106116
for m in range(len(metrics)):
107117
row[3 + m] = baseline[lang_pair][chap][m]
108-
rows.append(row)
118+
rows.append(row)
119+
109120
return rows
110121

111122

@@ -228,8 +239,10 @@ def main() -> None:
228239
global metrics
229240
global key_word
230241

231-
parser = argparse.ArgumentParser(description="Pull results")
232-
parser.add_argument("exp1", type=str, help="Experiment folder")
242+
parser = argparse.ArgumentParser(
243+
description="Pull results. At least one --exp or --baseline needs to be specified."
244+
)
245+
parser.add_argument("--exp", type=str, help="Experiment folder with progression results")
233246
parser.add_argument(
234247
"--trained-books", nargs="*", required=True, type=str.upper, help="Books that are trained in the exp"
235248
)
@@ -243,29 +256,34 @@ def main() -> None:
243256
help="Metrics that will be analyzed with",
244257
)
245258
parser.add_argument("--key-word", type=str, default="conf", help="Key word in the filename for the exp group")
246-
parser.add_argument("--baseline", type=str, help="Baseline for the exp group")
259+
parser.add_argument("--baseline", type=str, help="Baseline or non-progression result for the exp group")
247260
args = parser.parse_args()
248261

262+
if not (args.exp or args.baseline):
263+
parser.error("At least one --exp or --baseline needs to be specified.")
264+
249265
trained_books = args.trained_books
250266
target_book = args.target_book
251267
all_books = trained_books + [target_book]
252268
metrics = args.metrics
253269
key_word = args.key_word
254270

255-
exp1_name = args.exp1
256-
exp1_dir = get_mt_exp_dir(exp1_name)
271+
exp1_name = args.exp
272+
exp1_dir = get_mt_exp_dir(exp1_name) if exp1_name else None
257273

258274
exp2_name = args.baseline
259275
exp2_dir = get_mt_exp_dir(exp2_name) if exp2_name else None
260276

261277
folder_name = "+".join(all_books)
262-
os.makedirs(os.path.join(exp1_dir, "a_result_folder"), exist_ok=True)
263-
output_path = os.path.join(exp1_dir, "a_result_folder", f"{folder_name}.xlsx")
278+
result_dir = exp1_dir if exp1_dir else exp2_dir
279+
os.makedirs(os.path.join(result_dir, "a_result_folder"), exist_ok=True)
280+
output_path = os.path.join(result_dir, "a_result_folder", f"{folder_name}.xlsx")
264281

265282
data = {}
266283
chapters = []
267-
read_data(exp1_dir, data, chapters)
268-
chapters = sorted(set(chapters))
284+
if exp1_dir:
285+
read_data(exp1_dir, data, chapters)
286+
chapters = sorted(set(chapters))
269287

270288
baseline_data = {}
271289
if exp2_dir:

0 commit comments

Comments
 (0)