Merge branch 'main' into add-reinforce
Quentin-Anthony authored Sep 29, 2024
2 parents c72b285 + a6d6af0 commit afbaac9
Showing 8 changed files with 249 additions and 75 deletions.
2 changes: 2 additions & 0 deletions configs/README.md
@@ -124,6 +124,8 @@ These can be set to any integer between `0` and `num_gpus`, and `num_gpus` must
# this should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"train_iters": 320000,
# alternatively, use train_epochs to automatically determine the number of training iterations
#"train_epochs": 1,
```
An example of some basic settings used to configure your model's architecture and number of training steps.
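
For intuition, here is a rough editorial sketch (not from the repository) of how an epoch budget maps onto an iteration budget, under the assumption that the data is packed so that every sample is exactly `seq_length` tokens:

```python
def iters_per_epoch(total_training_tokens, seq_length, train_batch_size):
    # one sample = one packed sequence of seq_length tokens,
    # one iteration = one global batch of train_batch_size samples
    samples_per_epoch = total_training_tokens // seq_length
    return samples_per_epoch // train_batch_size

# e.g. a 1B-token corpus, 2048-token sequences, global batch size 256
print(iters_per_epoch(1_000_000_000, 2048, 256))  # -> 1907
```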

18 changes: 16 additions & 2 deletions configs/neox_arguments.md
@@ -14,14 +14,19 @@ LR Scheduler Arguments
Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.



- **lr_decay_iters**: int

Default = None

Number of iterations to decay learning rate over, If None defaults to --train-iters
Number of iterations to decay learning rate over. If None, defaults to
--train-iters or the equivalent inferred value from train_epochs.

- **lr_decay_fraction**: float

Default = None

Effective fraction of training over which to decay lr. Overrides lr_decay_iters.
Useful when specifying train_epochs.
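
A minimal sketch of how this could resolve to a concrete decay horizon, assuming the total number of training iterations has already been determined (either from `train_iters` or inferred from `train_epochs`); the helper name is illustrative, not taken from the codebase:

```python
def resolve_lr_decay_iters(train_iters, lr_decay_iters=None, lr_decay_fraction=None):
    # lr_decay_fraction takes precedence over an explicit lr_decay_iters
    if lr_decay_fraction is not None:
        return int(train_iters * lr_decay_fraction)
    # otherwise fall back to lr_decay_iters, or decay over the whole run
    return lr_decay_iters if lr_decay_iters is not None else train_iters
```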

- **min_lr**: float

@@ -1928,6 +1933,15 @@ Training Arguments
- **train_epochs**: int

Default = None

Number of epochs to run for training. Do not specify both train_epochs and train_iters.
Not currently compatible with data reweighting, pairwise datasets, or packing implementations other than 'packed'.

- **eval_iters**: int

Default = 100
149 changes: 118 additions & 31 deletions megatron/data/data_utils.py
@@ -16,7 +16,7 @@
import torch
import numpy as np
from typing import List, Tuple
from itertools import zip_longest
from itertools import zip_longest, cycle
from functools import partial

from megatron import mpu, print_rank_0
@@ -63,6 +63,7 @@ def build_the_dataset(
dataset_impl,
allow_chopped,
num_samples,
num_epochs,
seq_length,
seed,
skip_warmup,
@@ -142,6 +143,7 @@ def build_the_dataset(
documents,
indexed_dataset,
num_samples,
num_epochs,
seq_length,
seed,
pack_impl=pack_impl,
@@ -180,6 +182,7 @@ def build_train_valid_test_datasets(
allow_chopped,
splits_string,
train_valid_test_num_samples,
train_valid_test_epochs,
seq_length,
seed,
skip_warmup,
@@ -220,6 +223,7 @@ def build_dataset(index, name):
documents,
indexed_dataset,
train_valid_test_num_samples[index],
train_valid_test_epochs[index],
seq_length,
seed,
pack_impl=pack_impl,
@@ -269,12 +273,15 @@ def get_normalized_weights_and_num_samples(
weight_sum = sum(weights)
assert weight_sum > 0.0
weights = [weight / weight_sum for weight in weights]
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
weighted_num_samples = []
for weight in weights:
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
if num_samples is not None:
# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
weighted_num_samples = []
for weight in weights:
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
else:
weighted_num_samples = [None for _ in weights]
return weights, weighted_num_samples
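
For illustration (editorial example, not part of the diff), the two modes of the updated helper behave as follows:

```python
# explicit sample budget: weights are normalized, and each per-dataset count
# is padded by ~0.5% so the blended dataset never runs dry
weights, counts = get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], 1000)
# weights -> [0.5, 0.25, 0.25]; counts -> [503, 252, 252]  (ceil(1000 * w * 1.005))

# epoch-driven run: no global sample budget yet, so per-dataset counts stay None
weights, counts = get_normalized_weights_and_num_samples([2.0, 1.0, 1.0], None)
# weights -> [0.5, 0.25, 0.25]; counts -> [None, None, None]
```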


@@ -283,9 +290,9 @@ def build_weighted_datasets(
train_num_samples,
valid_num_samples,
test_num_samples,
train_weights,
valid_weights,
test_weights,
train_epochs,
valid_epochs,
test_epochs,
build_index_mappings=True,
):
# build individual datasets
@@ -368,6 +375,7 @@ def build_weighted_datasets(
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=train_num_samples[i],
num_epochs=train_epochs,
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
@@ -392,6 +400,7 @@ def build_weighted_datasets(
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=valid_num_samples[i],
num_epochs=valid_epochs,
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
@@ -416,6 +425,7 @@ def build_weighted_datasets(
pack_impl=neox_args.pack_impl,
allow_chopped=neox_args.allow_chopped,
num_samples=test_num_samples[i],
num_epochs=test_epochs,
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
@@ -470,9 +480,44 @@ def weights_by_num_docs(l: list, alpha=0.3):
return weights


def build_train_valid_test_data_iterators(neox_args):
def validate_train_epochs(neox_args):
"""Check for unsupported neox_args when using train_epochs instead of train_iters"""
if neox_args.train_epochs is None:
return

if neox_args.train_epochs and neox_args.train_iters:
raise ValueError(
"Cannot specify both train epochs and train iters simultaneously"
)

if neox_args.pack_impl != "packed":
raise ValueError(
"Packing implementations other than 'packed' are currently unsupported with train_epochs"
)

if neox_args.weight_by_num_documents:
raise ValueError(
"Weighting by number of documents is currently unsupported with train_epochs"
)

if neox_args.train_data_weights and (
not all(weight == 1.0 for weight in neox_args.train_data_weights)
):
raise ValueError(
"train_data_weights != None is currently unsupported with train_epochs"
)

if neox_args.dataset_impl != "gpt2":
raise ValueError(
"Datasets other than 'gpt2' are currently unsupported with train_epochs"
)


def build_train_valid_test_data_loaders(neox_args):
"""Build the train, validation, and test data loaders based on neox_args."""

validate_train_epochs(neox_args)

(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

print_rank_0("> building train, validation, and test datasets ...")
@@ -539,14 +584,21 @@ def build_train_valid_test_data_iterators(neox_args):
flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
elif mpu.get_model_parallel_rank() == 0 and pipe_load:
# Number of train/valid/test samples.
train_iters = neox_args.train_iters
eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
test_iters = neox_args.eval_iters
train_val_test_num_samples = [
train_iters * neox_args.train_batch_size,
eval_iters * neox_args.train_batch_size,
test_iters * neox_args.train_batch_size,
]
if neox_args.train_iters is not None:
train_iters = neox_args.train_iters
eval_iters = (
train_iters // neox_args.eval_interval + 1
) * neox_args.eval_iters
test_iters = neox_args.eval_iters
train_val_test_num_samples = [
train_iters * neox_args.train_batch_size,
eval_iters * neox_args.train_batch_size,
test_iters * neox_args.train_batch_size,
]
train_val_test_epochs = [None, None, None]
elif neox_args.train_epochs is not None:
train_val_test_num_samples = [None, None, None]
train_val_test_epochs = [1, 1, 1]

if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
# when individual train / valid / test data paths are provided
@@ -567,9 +619,9 @@ def build_train_valid_test_data_iterators(neox_args):
train_num_samples,
valid_num_samples,
test_num_samples,
train_weights,
valid_weights,
test_weights,
train_val_test_epochs[0],
train_val_test_epochs[1],
train_val_test_epochs[2],
build_index_mappings=not neox_args.weight_by_num_documents,
)

@@ -615,9 +667,9 @@ def build_train_valid_test_data_iterators(neox_args):
train_num_samples,
valid_num_samples,
test_num_samples,
train_weights,
valid_weights,
test_weights,
train_val_test_epochs[0],
train_val_test_epochs[1],
train_val_test_epochs[2],
)

if train_datasets:
@@ -635,6 +687,7 @@ def build_train_valid_test_data_iterators(neox_args):
data_impl=neox_args.data_impl,
splits_string=neox_args.split,
train_valid_test_num_samples=train_val_test_num_samples,
train_valid_test_epochs=train_val_test_epochs,
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
@@ -648,9 +701,15 @@ def build_train_valid_test_data_iterators(neox_args):
test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and neox_args.train_iters > 0
do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
do_test = test_dataloader is not None and neox_args.eval_iters > 0
if neox_args.train_epochs:
do_train = train_dataloader is not None
do_valid = valid_dataloader is not None
do_test = test_dataloader is not None
else:
do_train = train_dataloader is not None and neox_args.train_iters > 0
do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
do_test = test_dataloader is not None and neox_args.eval_iters > 0

# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
else:
@@ -670,6 +729,19 @@ def build_train_valid_test_data_iterators(neox_args):
neox_args.do_train = flags[0].item()
neox_args.do_valid = flags[1].item()
neox_args.do_test = flags[2].item()
data_loaders = {
"train": train_dataloader,
"valid": valid_dataloader,
"test": test_dataloader,
}
return data_loaders
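
A hedged sketch of how the two halves are meant to compose at the call site (the call site itself is not part of this diff, and the tuple return of `shift_and_wrap_data_loaders` is assumed from the iterator variables built in the function just below):

```python
# build the loaders once, then wrap them in iterators
data_loaders = build_train_valid_test_data_loaders(neox_args)
train_iter, valid_iter, test_iter = shift_and_wrap_data_loaders(
    neox_args, data_loaders, loop=True  # loop=True replays each loader via itertools.cycle
)
```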


def shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True):
"""Shift start iteration and wrap data_loaders in iterators"""
train_dataloader = data_loaders["train"]
valid_dataloader = data_loaders["valid"]
test_dataloader = data_loaders["test"]

# Shift the start iterations.
if train_dataloader is not None:
@@ -695,19 +767,34 @@ def build_train_valid_test_data_iterators(neox_args):
)
)

def loop_iterator(data_loader):
while True:
for x in data_loader:
yield x
data_loader.start_iter = 0

# Build iterators.
if train_dataloader is not None:
train_data_iterator = iter(train_dataloader)
if loop:
train_data_iterator = cycle(train_dataloader)
else:
train_data_iterator = iter(train_dataloader)
else:
train_data_iterator = None

if valid_dataloader is not None:
valid_data_iterator = iter(valid_dataloader)
if loop:
valid_data_iterator = cycle(valid_dataloader)
else:
valid_data_iterator = iter(valid_dataloader)
else:
valid_data_iterator = None

if test_dataloader is not None:
test_data_iterator = iter(test_dataloader)
if loop:
test_data_iterator = cycle(test_dataloader)
else:
test_data_iterator = iter(test_dataloader)
else:
test_data_iterator = None

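For readers unfamiliar with `itertools.cycle` (which replaces the removed `loop_iterator` helper above): it caches the items produced during the first pass over the loader and then replays that cached sequence indefinitely. A tiny self-contained illustration (editorial, not repository code):

```python
from itertools import cycle

looped = cycle(iter([1, 2, 3]))          # source iterator is exhausted after one pass
print([next(looped) for _ in range(7)])  # [1, 2, 3, 1, 2, 3, 1] -- replayed from cycle's cache
```
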
6 changes: 5 additions & 1 deletion megatron/data/gpt2_dataset.py
@@ -34,6 +34,7 @@ def __init__(
documents,
indexed_dataset,
num_samples,
num_epochs,
seq_length,
seed,
pack_impl="packed",
@@ -70,6 +71,7 @@ def __init__(
self.indexed_dataset.sizes,
self.label_dataset,
num_samples,
num_epochs,
seq_length,
seed,
self.pack_impl,
@@ -203,6 +205,7 @@ def _build_index_mappings(
sizes,
label_dataset,
num_samples,
num_epochs,
seq_length,
seed,
packing_impl,
@@ -217,7 +220,8 @@
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
if not num_epochs:
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)

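For context, the pre-existing `_num_epochs` fallback used above computes the smallest number of passes over the corpus that yields the requested number of samples. A sketch consistent with the surrounding code (the real helper may differ in details):

```python
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
    # keep adding full passes over the data until enough seq_length-token
    # samples can be drawn (the -1 accounts for the shifted-label token)
    epochs, total_tokens = 0, 0
    while True:
        epochs += 1
        total_tokens += tokens_per_epoch
        if (total_tokens - 1) // seq_length >= num_samples:
            return epochs
```
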
6 changes: 5 additions & 1 deletion megatron/data/samplers.py
@@ -100,7 +100,11 @@ class DistributedBatchSampler(data.sampler.BatchSampler):
specifying True will result in the following samples for each gpu:
GPU0: [0,2,4,6] GPU1: [1,3,5,7]
specifying False will result in the following samples:
GPU0: [0,1,2,3] GPU1: [4,5,6,7]"""
GPU0: [0,1,2,3] GPU1: [4,5,6,7]
The `infinite_loop` parameter allows the sampler to yield batches indefinitely,
restarting from the beginning of the dataset when all samples have been iterated over.
"""

def __init__(
self,
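
An illustrative sketch (not repository code) of the two rank-splitting behaviors the docstring above describes, for a global batch of 8 indices split across 2 GPUs:

```python
batch = list(range(8))   # one global batch of sample indices
world_size = 2

# interleaved split: each rank takes every world_size-th index
print([batch[rank::world_size] for rank in range(world_size)])
# -> [[0, 2, 4, 6], [1, 3, 5, 7]]

# contiguous split: each rank takes a consecutive block
chunk = len(batch) // world_size
print([batch[rank * chunk:(rank + 1) * chunk] for rank in range(world_size)])
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]
```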