
Commit 57a6439

revert multi-pack sampler changes
1 parent a7d6f6f

2 files changed: +27 −7 lines changed

src/instructlab/training/multipack_sampler.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -34,8 +34,6 @@
 import torch
 import torch.distributed as dist
 
-from instructlab.training.utils import bucket
-
 
 def find_max_pack_len_with_padding(
     dataset,
@@ -213,11 +211,11 @@ def ffd_check_padding(a: np.ndarray, c: int, n: int):
         not_found = True
         for idx in range(n):
             # Calculate the new capacity if size is added to the bin
-            new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
+            new_capacity = max(bins_max_lengths[idx], size) * (
                 bins_num_samples[idx] + 1
             )
             if new_capacity <= c:
-                bins_max_lengths[idx] = bucket(max(bins_max_lengths[idx], size))
+                bins_max_lengths[idx] = max(bins_max_lengths[idx], size)
                 bins_num_samples[idx] += 1
                 not_found = False
                 break
@@ -268,11 +266,11 @@ def ffd_with_result_padding(a: np.ndarray, c: int, start_index: int):
         add_new = True
         for idx in range(len(bins_max_lengths)):
             # Calculate the new capacity if size is added to the bin
-            new_capacity = bucket(max(bins_max_lengths[idx], size)) * (
+            new_capacity = max(bins_max_lengths[idx], size) * (
                 bins_num_samples[idx] + 1
            )
             if new_capacity <= c:
-                bins_max_lengths[idx] = bucket(max(bins_max_lengths[idx], size))
+                bins_max_lengths[idx] = max(bins_max_lengths[idx], size)
                 bins_num_samples[idx] += 1
                 bins_result[idx].append(indices[a_id] + start_index)
                 add_new = False
```
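
For context on the revert above: the bin cost in these FFD helpers goes back to (longest sequence in the bin) × (number of sequences in the bin), without the extra `bucket()` rounding. Below is a minimal, self-contained sketch of that padded-capacity first-fit-decreasing check; the function name, setup, and example values are illustrative only, not the upstream implementation.

```python
import numpy as np


def ffd_check_padding_sketch(lengths: np.ndarray, c: int, n: int) -> bool:
    """Return True if `lengths` fit into `n` bins whose padded cost stays <= `c`."""
    lengths = np.sort(lengths)[::-1]                # first-fit-decreasing: longest first
    bins_max_lengths = np.zeros(n, dtype=np.int64)  # longest sequence placed in each bin
    bins_num_samples = np.zeros(n, dtype=np.int64)  # number of sequences in each bin

    for size in lengths:
        not_found = True
        for idx in range(n):
            # Padded cost if `size` joins bin `idx`: every sample in the bin is
            # padded up to the bin's longest sequence.
            new_capacity = max(bins_max_lengths[idx], size) * (bins_num_samples[idx] + 1)
            if new_capacity <= c:
                bins_max_lengths[idx] = max(bins_max_lengths[idx], size)
                bins_num_samples[idx] += 1
                not_found = False
                break
        if not_found:
            return False
    return True


# Example: [7, 5, 3, 2] fits into 2 bins of padded capacity 16:
# bin 0 -> [7, 5] costs 7 * 2 = 14, bin 1 -> [3, 2] costs 3 * 2 = 6.
assert ffd_check_padding_sketch(np.array([7, 5, 3, 2]), c=16, n=2)
```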

src/instructlab/training/utils.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -242,6 +242,28 @@ def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bo
 
 @numba.njit
 def simple_bucket(length):
+    """
+    This bucket algorithm relies only on the given number instead of being based on
+    slicing the known (min, max) range, for several reasons:
+        1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
+           (min, max) sequence length of each rank will be much smaller than the
+           (min, max) sequence length of the dataset. Bucketing on the
+           (min, max) sequence length of the dataset is not practical.
+        2) The (min, max) sequence length of a given rank is unknown until
+           finishing 1 epoch since the packing is done on the fly.
+        3) Due to the shuffling, the (min, max) sequence length of a given rank
+           may vary between ranks. Once the (min, max) sequence length of a
+           given rank changes, the bucketing also needs adjustment.
+
+    This bucket algorithm is based on the most significant set bit of the input number.
+    It first checks what's the most significant set bit, assuming it's bit "S",
+    and then slices the range [2 ** S, 2 ** (S+1)] into buckets of the same size.
+    By default the range is divided into 16 buckets, so the bucket size will be
+    2 ** (S - 4).
+    For example, 0b10001 will be padded to 0b10010.
+    This approach can limit the overhead of bucketing (at most 1/16 of the input
+    number) and also prevent recompilation due to a too-small bucket size.
+    """
     l = length
     msb = 0
     while l > 0:
```
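
The docstring added above describes the bucketing rule in prose. A small sketch of the idea is shown below; it assumes a 16-way split and rounding up to the next strictly larger bucket boundary (to match the 0b10001 → 0b10010 example), and is an illustration rather than the committed `simple_bucket` body.

```python
def msb_bucket_sketch(length: int, num_buckets: int = 16) -> int:
    """Pad `length` up to the next boundary of a bucket derived from its MSB."""
    if length <= 0:
        return length
    msb = length.bit_length() - 1                    # position S of the most significant set bit
    bucket_size = max(1, (2 ** msb) // num_buckets)  # 2 ** (S - 4) for 16 buckets
    # Round up to the next bucket boundary; the overhead is at most one bucket,
    # i.e. roughly 1/16 of the input number.
    return (length // bucket_size + 1) * bucket_size


assert msb_bucket_sketch(0b10001) == 0b10010  # 17 -> 18, as in the docstring example
assert msb_bucket_sketch(1000) == 1024        # S = 9, bucket size 32
```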
```diff
@@ -439,7 +461,7 @@ def reduce_sum_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
-        **deprecated_arguments if is_torch_hpu_available() else None,
+        **_deprecated_arguments if is_torch_hpu_available() else None,
     )
 
     return_dict = isinstance(output, dict)
```
