Skip to content

Commit 8cffbfd

Browse files
committed
fix duplication issues
1 parent 0754793 commit 8cffbfd

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

bigcode_eval/evaluator.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ def generate_text(self, task_name, intermediate_generations=None):
5959
solutions = [[ref] for ref in references]
6060
return solutions, references
6161

62-
generations = [] # list[list[str | None] | None]
62+
curr_generations = [] # list[list[str | None] | None]
6363
if intermediate_generations:
64-
generations = [gen for gen in intermediate_generations if gen]
65-
n_tasks -= len(generations)
64+
curr_generations = [gen for gen in intermediate_generations if gen]
65+
n_tasks -= len(curr_generations)
6666
intermediate_save_generations_path = f"{os.path.splitext(self.args.save_generations_path)[0]}_{task_name}_intermediate.json"
67-
curr_sample_idx = len(generations)
67+
curr_sample_idx = len(curr_generations)
6868

69-
new_generations = parallel_generations(
69+
generations = parallel_generations(
7070
task,
7171
dataset,
7272
self.accelerator,
@@ -76,10 +76,9 @@ def generate_text(self, task_name, intermediate_generations=None):
7676
args=self.args,
7777
curr_sample_idx=curr_sample_idx, # curr_sample_idx will added to limit_start to fix indexing
7878
save_every_k_tasks=self.args.save_every_k_tasks,
79-
intermediate_generations=generations,
79+
intermediate_generations=curr_generations,
8080
intermediate_save_generations_path=intermediate_save_generations_path,
8181
)
82-
generations.extend(new_generations)
8382

8483
if len(generations[0]) > self.args.n_samples:
8584
generations = [l[: self.args.n_samples] for l in generations]

bigcode_eval/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
import warnings
55
from collections import defaultdict
6+
from copy import deepcopy
67
from typing import List, Optional
78

89
import torch
@@ -334,8 +335,9 @@ def complete_code(
334335
gen_token_dict,
335336
)
336337
with open(intermediate_save_generations_path, "w") as fp:
337-
intermediate_generations.extend(code_gens)
338-
json.dump(intermediate_generations, fp)
338+
intermediate_save_generations = deepcopy(intermediate_generations)
339+
intermediate_save_generations.extend(code_gens)
340+
json.dump(intermediate_save_generations, fp)
339341
print(
340342
f"intermediate generations were saved at {intermediate_save_generations_path}"
341343
)
@@ -353,7 +355,7 @@ def complete_code(
353355
gen_token_dict,
354356
)
355357

356-
return code_gens
358+
return intermediate_generations.extend(code_gens)
357359

358360

359361
def update_code_gens(

0 commit comments

Comments
 (0)