@@ -90,22 +90,33 @@ def flatten_dict(data: dict, chapters: list, baseline={}) -> list:
9090 global metrics
9191
9292 rows = []
93- for lang_pair in data :
94- for chap in range (1 , chap_num + 1 ):
95- row = [lang_pair , chap ]
96- row .extend ([None , None , None ] * len (metrics ) * len (data [lang_pair ]))
97- row .extend ([None ] * len (chapters ))
98- row .extend ([None ] * (1 + len (metrics )))
99-
100- for res_chap in data [lang_pair ]:
101- if chap in data [lang_pair ][res_chap ]:
93+ if len (data ) > 0 :
94+ for lang_pair in data :
95+ for chap in range (1 , chap_num + 1 ):
96+ row = [lang_pair , chap ]
97+ row .extend ([None , None , None ] * len (metrics ) * len (data [lang_pair ]))
98+ row .extend ([None ] * len (chapters ))
99+ row .extend ([None ] * (1 + len (metrics )))
100+
101+ for res_chap in data [lang_pair ]:
102+ if chap in data [lang_pair ][res_chap ]:
103+ for m in range (len (metrics )):
104+ index_m = 3 + 1 + len (metrics ) + chapters .index (res_chap ) * (len (metrics ) * 3 + 1 ) + m * 3
105+ row [index_m ] = data [lang_pair ][res_chap ][chap ][m ]
106+ if len (baseline ) > 0 :
102107 for m in range (len (metrics )):
103- index_m = 3 + 1 + len (metrics ) + chapters .index (res_chap ) * (len (metrics ) * 3 + 1 ) + m * 3
104- row [index_m ] = data [lang_pair ][res_chap ][chap ][m ]
105- if len (baseline ) > 0 :
108+ row [3 + m ] = baseline [lang_pair ][chap ][m ]
109+ rows .append (row )
110+ else :
111+ for lang_pair in baseline :
112+ for chap in range (1 , chap_num + 1 ):
113+ row = [lang_pair , chap ]
114+ row .extend ([None ] * (1 + len (metrics )))
115+
106116 for m in range (len (metrics )):
107117 row [3 + m ] = baseline [lang_pair ][chap ][m ]
108- rows .append (row )
118+ rows .append (row )
119+
109120 return rows
110121
111122
@@ -228,8 +239,10 @@ def main() -> None:
228239 global metrics
229240 global key_word
230241
231- parser = argparse .ArgumentParser (description = "Pull results" )
232- parser .add_argument ("exp1" , type = str , help = "Experiment folder" )
242+ parser = argparse .ArgumentParser (
243+ description = "Pull results. At least one --exp or --baseline needs to be specified."
244+ )
245+ parser .add_argument ("--exp" , type = str , help = "Experiment folder with progression results" )
233246 parser .add_argument (
234247 "--trained-books" , nargs = "*" , required = True , type = str .upper , help = "Books that are trained in the exp"
235248 )
@@ -243,29 +256,34 @@ def main() -> None:
243256 help = "Metrics that will be analyzed with" ,
244257 )
245258 parser .add_argument ("--key-word" , type = str , default = "conf" , help = "Key word in the filename for the exp group" )
246- parser .add_argument ("--baseline" , type = str , help = "Baseline for the exp group" )
259+ parser .add_argument ("--baseline" , type = str , help = "Baseline or non-progression result for the exp group" )
247260 args = parser .parse_args ()
248261
262+ if not (args .exp or args .baseline ):
263+ parser .error ("At least one --exp or --baseline needs to be specified." )
264+
249265 trained_books = args .trained_books
250266 target_book = args .target_book
251267 all_books = trained_books + [target_book ]
252268 metrics = args .metrics
253269 key_word = args .key_word
254270
255- exp1_name = args .exp1
256- exp1_dir = get_mt_exp_dir (exp1_name )
271+ exp1_name = args .exp
272+ exp1_dir = get_mt_exp_dir (exp1_name ) if exp1_name else None
257273
258274 exp2_name = args .baseline
259275 exp2_dir = get_mt_exp_dir (exp2_name ) if exp2_name else None
260276
261277 folder_name = "+" .join (all_books )
262- os .makedirs (os .path .join (exp1_dir , "a_result_folder" ), exist_ok = True )
263- output_path = os .path .join (exp1_dir , "a_result_folder" , f"{ folder_name } .xlsx" )
278+ result_dir = exp1_dir if exp1_dir else exp2_dir
279+ os .makedirs (os .path .join (result_dir , "a_result_folder" ), exist_ok = True )
280+ output_path = os .path .join (result_dir , "a_result_folder" , f"{ folder_name } .xlsx" )
264281
265282 data = {}
266283 chapters = []
267- read_data (exp1_dir , data , chapters )
268- chapters = sorted (set (chapters ))
284+ if exp1_dir :
285+ read_data (exp1_dir , data , chapters )
286+ chapters = sorted (set (chapters ))
269287
270288 baseline_data = {}
271289 if exp2_dir :
0 commit comments