File tree 1 file changed +12
-5
lines changed 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -297,11 +297,18 @@ def complete_code(
297
297
** gen_kwargs ,
298
298
)
299
299
else :
300
- generated_tokens = model .generate (
301
- input_ids = inputs ,
302
- num_return_sequences = batch_size ,
303
- ** gen_kwargs ,
304
- )
300
+ # In transformers (>= 4.40.2), if the length of input_ids == max_length, a ValueError is thrown.
301
+ # We want to ignore this error in order to reproduce old results with mbpp.
302
+ try :
303
+ generated_tokens = model .generate (
304
+ input_ids = inputs ,
305
+ num_return_sequences = batch_size ,
306
+ ** gen_kwargs ,
307
+ )
308
+ except ValueError :
309
+ # When the length of input_ids == max_length, the generation is the same as the input
310
+ generated_tokens = inputs
311
+
305
312
# each task is generated batch_size times
306
313
generated_tasks = batch ["task_id" ].repeat (batch_size )
307
314
generated_tokens = accelerator .pad_across_processes (
You can’t perform that action at this time.
0 commit comments