Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Zero Bubble (ZB-1H) scheduling into FlagScale, which splits BW into B and W. #405

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions examples/aquila/conf/train/demo_zb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
system:
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 4
disable_bias_linear: True
use_flash_attn: True
# sequence_parallel: True
sequence_parallel: False
use_distributed_optimizer: True
transformer_impl: local
enable_zero_bubble: True
precision:
bf16: True
attention_softmax_in_fp32: True
accumulate_allreduce_grads_in_fp32: True
logging:
log_interval: 1
tensorboard_log_interval: 1
wandb_project: "aquila2"
wandb_exp_name: "test"
checkpoint:
save_interval: 1000


model:
num_layers: 256
hidden_size: 1024
num_attention_heads: 32
seq_length: 1024
max_position_embeddings: 2048
norm_epsilon: 1e-5
use_rotary_position_embeddings: true
no_position_embedding: true
swiglu: true
multiple_of: 256
# normalization: RMSNorm
normalization: LayerNorm
untie_embeddings_and_output_weights: true
init_method_std: 0.0165
attention_dropout: 0.0
hidden_dropout: 0.0
weight_decay: 0.1
clip_grad: 1.0
train_samples: 16
global_batch_size: 8
micro_batch_size: 1
# rampup_batch_size: [32, 32, 2000000]
seed: 42

optimizer:
lr: 2e-4
weight_decay: 0.01
adam_beta1: 0.9
adam_beta2: 0.95
lr_scheduler:
lr: 1.5e-4
min_lr: 1.5e-5
lr_warmup_samples: 0
lr_decay_style: cosine

data:
data_path: /share/project/caozhou/adaptive_flash_ckpt/FlagScale/data/pile_wikipedia_demo # Please replace with your actual data path
split: 1
tokenizer:
tokenizer_type: AquilaTokenizerFS
vocab_file: ./examples/aquila/tokenizer/vocab.json
merge_file: ./examples/aquila/tokenizer/merges.txt
special_tokens_file: ./examples/aquila/tokenizer/special_tokens.txt
vocab_size: 100008
4 changes: 2 additions & 2 deletions flagscale/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def train_step(forward_step_func, data_iterator,
optimizer.zero_grad()

# Forward pass.
forward_backward_func = get_forward_backward_func()
forward_backward_func = get_forward_backward_func(enable_zero_bubble = get_args().enable_zero_bubble)
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
Expand Down Expand Up @@ -1994,7 +1994,7 @@ def evaluate(forward_step_func,
if verbose:
print_rank_0(f'Evaluating iter {iteration}/{eval_iters}')

forward_backward_func = get_forward_backward_func()
forward_backward_func = get_forward_backward_func(enable_zero_bubble = get_args().enable_zero_bubble)
# Don't care about timing during evaluation
config.timers = None
ft_integration.on_eval_step_start()
Expand Down
62 changes: 62 additions & 0 deletions flagscale/train/weight_grad_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# ref: https://github.com/sail-sg/zero-bubble-pipeline-parallelism/tree/zb-h1-quick-start
import queue

class WeightGradStore:

cache = []
weight_grad_queue = queue.Queue()
split_bw = True
enable_zero_bubble = False

@classmethod
def is_supported(cls):
return True
"""If not supported, fallback to original schedule."""
# args = get_args()
# if args.pipeline_model_parallel_size <= 1:
# return False
# if args.virtual_pipeline_model_parallel_size is not None:
# return False
# if args.overlap_grad_reduce:
# # the logic of overlapping grad reduce should be changed
# return False
# if args.transformer_impl == 'transformer_engine':
# # hard to capture weight gradient computation for transformer_engine
# return False
# if args.sequence_parallel:
# # not supported in this commit
# return False
# return True

@classmethod
def put(cls, total_input, grad_output, weight, func):
if not cls.split_bw or not cls.is_supported():
func(total_input, grad_output, weight.main_grad)
return
# Store the weight gradient computation of linear layers.
cls.cache.append((total_input, grad_output, weight, func))

@classmethod
def flush(cls):
if not cls.is_supported():
return
# Collect all stored computations during backward as a W.
cls.weight_grad_queue.put(cls.cache)
cls.cache = []

@classmethod
def pop(cls):
if not cls.is_supported():
return
# Execute a single W.
assert cls.weight_grad_queue.qsize() > 0
stored_grads = cls.weight_grad_queue.get()
for total_input, grad_output, weight, func in stored_grads:
func(total_input, grad_output, weight.main_grad)

@classmethod
def pop_all(cls):
# Execute all remaining W.
remaining_qsize = cls.weight_grad_queue.qsize()
for _ in range(remaining_qsize):
cls.pop()
Loading
Loading