
Commit 3dd0076 (1 parent: 84b96da)

Adding support for transformers>=4.40.2 to avoid crash with mbpp

File tree

1 file changed: +12 -5 lines changed


bigcode_eval/utils.py

Lines changed: 12 additions & 5 deletions
@@ -297,11 +297,18 @@ def complete_code(
                 **gen_kwargs,
             )
         else:
-            generated_tokens = model.generate(
-                input_ids=inputs,
-                num_return_sequences=batch_size,
-                **gen_kwargs,
-            )
+            # In transformers (>= 4.40.2), if the length of input_ids == max_length, a ValueError is thrown.
+            # We want to ignore this error in order to reproduce old results with mbpp.
+            try:
+                generated_tokens = model.generate(
+                    input_ids=inputs,
+                    num_return_sequences=batch_size,
+                    **gen_kwargs,
+                )
+            except ValueError:
+                # When the length of input_ids == max_length, the generation is the same as the input
+                generated_tokens = inputs

         # each task is generated batch_size times
         generated_tasks = batch["task_id"].repeat(batch_size)
         generated_tokens = accelerator.pad_across_processes(
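The pattern in this patch (try to generate, and fall back to returning the prompt unchanged when the prompt already consumes the whole generation budget) can be sketched in isolation. This is a minimal illustration, not the transformers API: `StrictModel` and `safe_generate` are hypothetical stand-ins that mimic the >= 4.40.2 behavior of raising a `ValueError` when there is no room left for new tokens.

```python
class StrictModel:
    """Hypothetical stand-in mimicking transformers >= 4.40.2 behavior:
    generate() raises ValueError if the input already fills max_length."""

    def generate(self, input_ids, max_length):
        if len(input_ids) >= max_length:
            raise ValueError("input_ids length equals max_length")
        # Pretend to generate by padding up to max_length with token 0.
        return input_ids + [0] * (max_length - len(input_ids))


def safe_generate(model, input_ids, max_length):
    """Generate, but fall back to the inputs when generation is impossible.

    When len(input_ids) == max_length, the "generation" would be identical
    to the input anyway, so returning the inputs reproduces the old behavior.
    """
    try:
        return model.generate(input_ids=input_ids, max_length=max_length)
    except ValueError:
        return input_ids


model = StrictModel()
print(safe_generate(model, [1, 2, 3], max_length=5))  # room to generate
print(safe_generate(model, [1, 2, 3], max_length=3))  # falls back to inputs
```

This keeps evaluation runs (e.g. mbpp with a fixed max_length) from crashing on prompts that already reach the length limit, while leaving normal generation untouched.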
