@@ -59,14 +59,14 @@ def generate_text(self, task_name, intermediate_generations=None):
5959 solutions = [[ref ] for ref in references ]
6060 return solutions , references
6161
62- generations = [] # list[list[str | None] | None]
62+ curr_generations = [] # list[list[str | None] | None]
6363 if intermediate_generations :
64- generations = [gen for gen in intermediate_generations if gen ]
65- n_tasks -= len (generations )
64+ curr_generations = [gen for gen in intermediate_generations if gen ]
65+ n_tasks -= len (curr_generations )
6666 intermediate_save_generations_path = f"{ os .path .splitext (self .args .save_generations_path )[0 ]} _{ task_name } _intermediate.json"
67- curr_sample_idx = len (generations )
67+ curr_sample_idx = len (curr_generations )
6868
69- new_generations = parallel_generations (
69+ generations = parallel_generations (
7070 task ,
7171 dataset ,
7272 self .accelerator ,
@@ -76,10 +76,9 @@ def generate_text(self, task_name, intermediate_generations=None):
7676 args = self .args ,
7777 curr_sample_idx = curr_sample_idx , # curr_sample_idx will added to limit_start to fix indexing
7878 save_every_k_tasks = self .args .save_every_k_tasks ,
79- intermediate_generations = generations ,
79+ intermediate_generations = curr_generations ,
8080 intermediate_save_generations_path = intermediate_save_generations_path ,
8181 )
82- generations .extend (new_generations )
8382
8483 if len (generations [0 ]) > self .args .n_samples :
8584 generations = [l [: self .args .n_samples ] for l in generations ]
0 commit comments