@@ -242,6 +242,28 @@ def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bo
 
 @numba.njit
 def simple_bucket(length):
245+ """
246+ This bucket algorithm merely relies on the given number instead of based on
247+ slicing the known (min, max) range for several reasons:
248+ 1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
249+ (min, max) sequence length of each rank will be much smaller than the
250+ (min, max) sequence length of the dataset. Bucketing on the
251+ (min, max) sequence length of the dataset is not practical
252+ 2) The (min, max) sequence length of a given rank is unknown until
253+ finishing 1 epoch since the packing is done on the fly
254+ 3) Due to the shuffling, the (min, max) sequence length of a given rank
255+ may vary between ranks. Once the (min, max) sequence length of a
256+ given rank changes, the bucketing also needs adjustment
257+
258+ This bucket algorithm is based on the most significant set bit of the input number.
259+ It first check what’s the most significant set bit, assuming it's bit "S",
260+ and then slice the range [2 ** S, 2 ** (S+1)] into buckets with the same size.
261+ By default the range is divided into 16 buckets, so the bucket size will be
262+ 2 ** (S - 4)
263+ For example, 0b10001 will be padded to 0b10010.
264+ This approach can limit the overhead of bucketing (at most 1/16 of the input
265+ number) and also prevent recompilation due to a too small bucket size.
266+ """
     l = length
     msb = 0
     while l > 0:
@@ -439,7 +461,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        **deprecated_arguments if is_torch_hpu_available() else None,
+        **_deprecated_arguments if is_torch_hpu_available() else None,
     )
 
     return_dict = isinstance(output, dict)
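As a rough illustration of the bucketing described in the docstring above, here is a minimal, self-contained sketch. The full body of simple_bucket is not visible in this diff, so the bit counting and round-up details below are assumptions based only on the docstring and its 0b10001 -> 0b10010 example; the name simple_bucket_sketch is hypothetical, and only the first three statements (l = length, msb = 0, while l > 0:) appear in the actual code.

# Minimal sketch of MSB-based bucketing (an assumption guided by the
# docstring and its example, not the actual simple_bucket implementation).
def simple_bucket_sketch(length: int) -> int:
    # Count how many bits are needed to represent `length`, i.e. the
    # position of the most significant set bit counted from 1.
    l = length
    msb = 0
    while l > 0:
        msb += 1
        l >>= 1
    # Split the power-of-two range into 16 buckets; for small numbers fall
    # back to a bucket size of 1 so the result never shrinks below `length`.
    bucket_size = 1 << max(msb - 4, 0)
    # Pad up to the next bucket boundary.
    return ((length + bucket_size - 1) // bucket_size) * bucket_size

print(bin(simple_bucket_sketch(0b10001)))  # 0b10010, matching the docstring example

The property this sketch preserves is the one the docstring relies on: padding overhead is bounded by one bucket (roughly 1/16 of the input), while the set of distinct padded lengths stays small enough to avoid frequent recompilation.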