
Commit 334efb7

Merge pull request #244 from meher-m/transformers_fix
Adding support for transformers>=4.40.2 to avoid crash with mbpp
2 parents: 4659ecd + cc9033c

bigcode_eval/utils.py

Lines changed: 16 additions & 5 deletions
@@ -297,11 +297,22 @@ def complete_code(
                     **gen_kwargs,
                 )
             else:
-                generated_tokens = model.generate(
-                    input_ids=inputs,
-                    num_return_sequences=batch_size,
-                    **gen_kwargs,
-                )
+                # In transformers (>= 4.40.2), if the length of input_ids == max_length, a ValueError is thrown.
+                # We want to ignore this error in order to reproduce old results with mbpp.
+                try:
+                    generated_tokens = model.generate(
+                        input_ids=inputs,
+                        num_return_sequences=batch_size,
+                        **gen_kwargs,
+                    )
+                except ValueError as e:
+                    # When the length of input_ids == max_length, the generation is the same as the input
+                    if str(e).startswith(f"Input length of input_ids is {inputs.shape[1]}, but `max_length` is set to {gen_kwargs['max_length']}"):
+                        warnings.warn(f"An error with the following message was thrown: {e}. Returning the input as the generation, for higher scores consider using a larger `max_length`")
+                        generated_tokens = inputs
+                    else:
+                        raise e
+
             # each task is generated batch_size times
             generated_tasks = batch["task_id"].repeat(batch_size)
             generated_tokens = accelerator.pad_across_processes(
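
For reference, a minimal standalone sketch of the guard introduced above, outside the harness. Assumptions not in the diff: "gpt2" as a stand-in checkpoint and a hypothetical prompt, with max_length deliberately set equal to the prompt length so that transformers >= 4.40.2 raises the ValueError being handled.

# Minimal sketch of the guard above (not the harness's own code).
# Assumes: transformers >= 4.40.2, "gpt2" as a stand-in model.
import warnings

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
# max_length == prompt length on purpose: newer transformers raises a
# ValueError here instead of silently returning the prompt unchanged.
gen_kwargs = {"max_length": inputs.shape[1]}

try:
    generated_tokens = model.generate(input_ids=inputs, **gen_kwargs)
except ValueError as e:
    if str(e).startswith(f"Input length of input_ids is {inputs.shape[1]}"):
        # No room left to generate, so treat the prompt itself as the generation.
        warnings.warn(f"{e} -- returning the input as the generation")
        generated_tokens = inputs
    else:
        raise

The startswith check keeps the except clause narrow: any other ValueError raised by model.generate() is re-raised unchanged, so only the known length-limit case is downgraded to a warning.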
