support smart batching in torchacc
eedalong committed Dec 25, 2024
1 parent 75dd4af commit 3f30b98
Showing 2 changed files with 134 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchacc/data/__init__.py
@@ -0,0 +1 @@
from .smart_batching import SmartBatchingSampler, flatten_mapfn_for_swift
133 changes: 133 additions & 0 deletions torchacc/data/smart_batching.py
@@ -0,0 +1,133 @@
import binpacking
import torch
import numpy as np
from typing import List, Dict, Any

def flatten_mapfn_for_swift(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Data collator used for padding free approach. Does the following:
- concatate the entire mini batch into single long sequence [1, total_tokens]
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
Args:
batch(`List[Dict[str, Any]]`): The input data in batch
padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
will be padded to the `longest`
"""
packed_data = {}
position_id_lengths = [len(item['input_ids']) for item in batch]
packed_data['input_ids'] = np.concatenate([item['input_ids'] for item in batch])
packed_data['labels'] = np.concatenate([item['labels'] for item in batch])
packed_data['position_ids'] = np.concatenate([list(range(pil)) for pil in position_id_lengths])
return packed_data


class SmartBatchingSampler:
"""Smart batching sampler for Megatron-LM.
Args:
dataset: A list of sequence lengths, each length is the length of a sequence.
total_samples: Total number of samples to be consumed.
micro_batch_size: Micro batch size.
data_parallel_rank: Data parallel rank.
data_parallel_size: Data parallel size.
consumed_samples: Consumed samples, mainly usedfor continue train from the last checkpoint.
"""
    def __init__(self,
                 dataset,                         # The underlying dataset (samples with 'input_ids')
                 dataset_type,                    # Workload type, currently only 'swift'
                 total_samples,                   # Total number of samples
                 micro_batch_size,                # Micro batch size
                 data_parallel_rank,              # Data parallel rank
                 data_parallel_size,              # Data parallel size
                 consumed_samples=0,              # Samples already consumed, used to resume from a checkpoint
                 balance_strategy='micro-batch',  # Balance strategy: 'micro-batch' or 'none'
                 ):
# Keep a copy of input params for later use.
self.dataset = dataset
self.total_samples = total_samples
self.dataset_type = dataset_type
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.balance_strategy = balance_strategy
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
assert self.balance_strategy in ['micro-batch', "none"], \
'invalid balance_strategy: {}, only {} and {} are supported'.format(self.balance_strategy, 'micro-batch', 'none')
        assert self.dataset_type in ['swift'], \
            'invalid dataset_type: {}, only {} is supported'.format(self.dataset_type, 'swift')

    def __len__(self):
return self.total_samples // self.data_parallel_size

def binpack_to_constant_bin_number_with_max_weight_limit(self, packages, max_length, bin_num):
"""A bin-packing algorithm to pack the packages into bins with constant bin number and max length limit""" \
"""Returns None if the max weight limit cannot be satisfied"""
        # Sort packages by length (descending) and sweep the bins forward then backward,
        # greedily placing the next package into the current bin whenever it fits.
        packages.sort(key=lambda item: item[1], reverse=True)
        bins = [[] for _ in range(bin_num)]
        bin_sum = [0] * bin_num
        package_idx = 0
        bin_idx_list = list(range(0, bin_num, 1)) + list(range(bin_num - 1, -1, -1))
while True:
processed = False
for bin_idx in bin_idx_list:
if bin_sum[bin_idx] + packages[package_idx][1] <= max_length:
bins[bin_idx].append(packages[package_idx])
bin_sum[bin_idx] += packages[package_idx][1]
package_idx += 1
processed = True
if package_idx == len(packages):
break
if not processed or package_idx == len(packages):
break
return bins if package_idx == len(packages) else None

def get_sequence_length(self, idx):
if self.dataset_type == "swift":
            return self.dataset[idx]['input_ids'].shape[0]

    def construct_balanced_batch(self, batch):
# No balancing, just flatten the batch
if self.balance_strategy == "none":
return batch[self.data_parallel_rank::self.data_parallel_size]
# Micro-batch level balancing
if self.balance_strategy == "micro-batch":
packages = {}
for idx, sample_idx in enumerate(batch):
packages[idx] = self.get_sequence_length(sample_idx)
bins = binpacking.to_constant_bin_number(packages, self.data_parallel_size)
current_batch = []
for idx in bins[self.data_parallel_rank].keys():
current_batch.append(batch[idx])
return current_batch

def __iter__(self):
        # Work out the current epoch and the position within it (for resuming from a checkpoint)
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Continue training from where it left off
g = torch.Generator()
g.manual_seed(self.epoch)
shuffle_samples = torch.randperm(self.total_samples, generator=g).tolist()
        shuffle_samples = shuffle_samples[current_epoch_samples:]
# Get one batch
batch = []
for idx in shuffle_samples:
batch.append(idx)
# Balance micro-batch across data parallel ranks
            if len(batch) == self.micro_batch_times_data_parallel_size:
                # consumed_samples counts samples across all data parallel ranks
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield self.construct_balanced_batch(batch)
                batch.clear()
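To make the collator's behavior concrete, here is a small illustration (not part of the commit) of what `flatten_mapfn_for_swift` returns for a toy batch of two sequences; the values are hypothetical.

import numpy as np

batch = [
    {'input_ids': np.array([1, 2, 3]), 'labels': np.array([1, 2, 3])},
    {'input_ids': np.array([4, 5]),    'labels': np.array([4, 5])},
]
packed = flatten_mapfn_for_swift(batch)
# packed['input_ids']    -> array([1, 2, 3, 4, 5])
# packed['labels']       -> array([1, 2, 3, 4, 5])
# packed['position_ids'] -> array([0, 1, 2, 0, 1])  (positions restart at each sequence boundary)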

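Below is a minimal sketch of how the sampler and collator might be wired into a training loop. It is not part of the commit: `train_dataset`, the micro batch size, and the use of `torch.distributed` for the data parallel rank and world size are assumptions. Because the sampler yields per-rank lists of sample indices, it can serve as a `batch_sampler` for `torch.utils.data.DataLoader`.

import torch.distributed as dist
from torch.utils.data import DataLoader

from torchacc.data import SmartBatchingSampler, flatten_mapfn_for_swift

# `train_dataset` is assumed to be a map-style dataset whose samples are dicts
# holding 'input_ids' and 'labels' arrays (the 'swift' workload type).
sampler = SmartBatchingSampler(
    dataset=train_dataset,
    dataset_type='swift',
    total_samples=len(train_dataset),
    micro_batch_size=4,                        # assumed value
    data_parallel_rank=dist.get_rank(),
    data_parallel_size=dist.get_world_size(),
    consumed_samples=0,
    balance_strategy='micro-batch',
)

# Each item yielded by the sampler is the list of indices assigned to this rank,
# so it plugs in as a batch_sampler; the collator then flattens each micro batch
# into a single padding-free sequence with restarted position_ids.
loader = DataLoader(train_dataset,
                    batch_sampler=sampler,
                    collate_fn=flatten_mapfn_for_swift)

for packed in loader:
    pass  # feed packed['input_ids'], packed['labels'], packed['position_ids'] to the model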