[WIP] support smart flatten collate in torchacc #47

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions torchacc/data/__init__.py
@@ -0,0 +1 @@
from .smart_batching import SmartBatchingSampler, flatten_mapfn_for_swift
123 changes: 123 additions & 0 deletions torchacc/data/smart_batching.py
@@ -0,0 +1,123 @@
from typing import Any, Dict, List

import binpacking
import numpy as np
import torch


def flatten_mapfn_for_swift(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Data collator used for padding free approach. Does the following:
- concatate the entire mini batch into single long sequence [1, total_tokens]
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
Args:
batch(`List[Dict[str, Any]]`): The input data in batch
padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
will be padded to the `longest`
"""
packed_data = {}
position_id_lengths = [len(item['input_ids']) for item in batch]
packed_data['input_ids'] = np.concatenate(
[item['input_ids'] for item in batch])
packed_data['labels'] = np.concatenate([item['labels'] for item in batch])
packed_data['position_ids'] = np.concatenate(
[list(range(pil)) for pil in position_id_lengths])
return packed_data


class SmartBatchingSampler:
"""Smart batching sampler for Megatron-LM.
Args:
dataset: A list of sequence lengths, each length is the length of a sequence.
total_samples: Total number of samples to be consumed.
micro_batch_size: Micro batch size.
data_parallel_rank: Data parallel rank.
data_parallel_size: Data parallel size.
consumed_samples: Consumed samples, mainly usedfor continue train from the last checkpoint.
"""

    def __init__(
        self,
        dataset,  # Dataset used to query per-sample sequence lengths
        dataset_type,  # Workload type
        total_samples,  # Total number of samples
        micro_batch_size,  # Micro batch size
        data_parallel_rank,  # Data parallel rank
        data_parallel_size,  # Data parallel size
        consumed_samples=0,  # Consumed samples, mainly used to continue training from the last checkpoint
        balance_strategy='micro-batch',  # Balance strategy
        seed=42,  # Random seed. It has to be the same on all processes.
    ):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.total_samples = total_samples
        self.dataset_type = dataset_type
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.balance_strategy = balance_strategy
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
        self.seed = seed
        self.generator = torch.Generator()

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)
        assert self.balance_strategy in ['micro-batch', 'none'], \
            'invalid balance_strategy: {}, only {} and {} are supported'.format(
                self.balance_strategy, 'micro-batch', 'none')
        assert self.dataset_type in ['swift'], \
            'invalid dataset_type: {}, only {} is supported'.format(
                self.dataset_type, 'swift')

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def get_sequence_length(self, idx):
        if self.dataset_type == "swift":
            return len(self.dataset[idx]['input_ids'])

    def construct_balanced_batch(self, batch):
        # No balancing: take this rank's strided slice of the global batch.
        if self.balance_strategy == "none":
            return batch[self.data_parallel_rank::self.data_parallel_size]
        # Micro-batch level balancing: pack samples into one bin per rank so
        # that the total sequence length per rank is roughly equal.
        if self.balance_strategy == "micro-batch":
            packages = {}
            for idx, sample_idx in enumerate(batch):
                packages[idx] = self.get_sequence_length(sample_idx)
            bins = binpacking.to_constant_bin_number(packages,
                                                     self.data_parallel_size)
            current_batch = []
            for idx in bins[self.data_parallel_rank].keys():
                current_batch.append(batch[idx])
            return current_batch

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        # Sanity check: resumption must be aligned to a full global batch.
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Continue training from where it left off.
        self.generator.manual_seed(self.seed + self.epoch)
        shuffle_samples = torch.randperm(
            self.total_samples, generator=self.generator).tolist()
        shuffle_samples = shuffle_samples[current_epoch_samples:]
        # Accumulate one global batch at a time.
        batch = []
        for idx in shuffle_samples:
            batch.append(idx)
            # Balance the micro-batch across data parallel ranks.
            if (self.balance_strategy == "micro-batch" or self.balance_strategy == "none") and \
                    len(batch) == self.micro_batch_times_data_parallel_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield self.construct_balanced_batch(batch)
                batch.clear()
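
Example (not part of the diff): a minimal sketch of how the new sampler and collate function might be wired together on a single data parallel rank. The toy dataset, its sizes, and the variable names below are illustrative assumptions, not code from this PR.

import numpy as np

from torchacc.data import SmartBatchingSampler, flatten_mapfn_for_swift

# Toy dataset in swift format: each sample is a dict with variable-length
# 'input_ids' and 'labels'.
rng = np.random.default_rng(0)
dataset = [{'input_ids': np.arange(n), 'labels': np.arange(n)}
           for n in rng.integers(4, 32, size=64)]

sampler = SmartBatchingSampler(
    dataset=dataset,
    dataset_type='swift',
    total_samples=len(dataset),
    micro_batch_size=4,
    data_parallel_rank=0,      # this rank's id
    data_parallel_size=2,      # number of data parallel ranks
    balance_strategy='micro-batch',
)

for sample_indices in sampler:
    # `sample_indices` is this rank's share of one global batch; with
    # 'micro-batch' balancing the share is chosen by total token count,
    # so it may not hold exactly `micro_batch_size` samples.
    micro_batch = [dataset[i] for i in sample_indices]
    packed = flatten_mapfn_for_swift(micro_batch)
    # All three outputs are flat arrays of the same total length;
    # `position_ids` restarts at 0 at every sample boundary.
    assert packed['input_ids'].shape == packed['position_ids'].shape
    break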
6 changes: 3 additions & 3 deletions torchacc/ops/__init__.py
@@ -1,6 +1,6 @@
-from .flash_attn import (flash_attn_varlen_qkvpacked_xla, flash_attn_varlen_xla,
-                         flash_attn_xla, spmd_flash_attn_varlen_xla,
-                         flash_attn_varlen_position_ids_xla)
+from .flash_attn import (flash_attn_varlen_position_ids_xla,
+                         flash_attn_varlen_qkvpacked_xla, flash_attn_varlen_xla,
+                         flash_attn_xla, spmd_flash_attn_varlen_xla)
 from .liger import (apply_liger_kernel, apply_liger_kernel_to_llama,
                     apply_liger_kernel_to_qwen2)
 from .scaled_dot_product_attention import scaled_dot_product_attention
2 changes: 2 additions & 0 deletions torchacc/utils/patch.py
@@ -74,6 +74,7 @@ def patch_fa():
"4.43.0") and version.parse(version_ts) <= version.parse(
"4.46.3"):
from typing import Optional

import transformers.modeling_flash_attention_utils as modeling_flash_attention_utils

def _flash_attention_forward(
@@ -253,6 +254,7 @@ def patch_qwen(use_flash_attn):
and replace flash_attn with the interface in torchacc. This requires transformers>=4.41.0.
'''
import inspect

import transformers
from packaging import version
